-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[MLIR][RemoveDeadValues] Mark arguments of a public function Live #162038
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
This diff also changes traversal order from forward to backward for region/block/ops. This order guanratees Liveness updates at a callsite can propagates to the defs of arguments. ``` ./bin/llvm-lit -v ../mlir/test/Transforms/remove-dead-values.mlir ```
@llvm/pr-subscribers-mlir Author: xin liu (navyxliu) ChangesThis diff is REDO of #160242. ProblemLiveness analysis is inter-procedural. If there are some unused arguments in a public function, they propagate to callers. From the perspective of RemoveDeadValues, the signature of a public function is immutable. It can't cope with this situation. One side, it deletes outgoing arguments, on the other side it keeps the function intact. SolutionWe deploy two methods to fix this bug.
Test plan
Full diff: https://github.com/llvm/llvm-project/pull/162038.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h
index cf1fd6e2d48ca..be7e027b95f64 100644
--- a/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h
@@ -102,6 +102,9 @@ struct RunLivenessAnalysis {
const Liveness *getLiveness(Value val);
+ /// Return the configuration of the solver used for this analysis.
+ const DataFlowConfig &getSolverConfig() const { return solver.getConfig(); }
+
private:
/// Stores the result of the liveness analysis that was run.
DataFlowSolver solver;
diff --git a/mlir/include/mlir/IR/Visitors.h b/mlir/include/mlir/IR/Visitors.h
index 893f66ae33deb..5766d262796d6 100644
--- a/mlir/include/mlir/IR/Visitors.h
+++ b/mlir/include/mlir/IR/Visitors.h
@@ -39,6 +39,20 @@ struct ForwardIterator {
}
};
+/// This iterator enumerates the elements in "backward" order.
+struct BackwardIterator {
+ template <typename T>
+ static auto makeIterable(T &range) {
+ if constexpr (std::is_same<T, Operation>()) {
+ /// Make operations iterable: return the list of regions.
+ return llvm::reverse(range.getRegions());
+ } else {
+ /// Regions and block are already iterable.
+ return llvm::reverse(range);
+ }
+ }
+};
+
/// A utility class to encode the current walk stage for "generic" walkers.
/// When walking an operation, we can either choose a Pre/Post order walker
/// which invokes the callback on an operation before/after all its attached
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index e0c65b0e09774..ed5c6a8d2ead0 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -33,6 +33,7 @@
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlow/LivenessAnalysis.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Dialect.h"
@@ -118,8 +119,13 @@ struct RDVFinalCleanupList {
/// Return true iff at least one value in `values` is live, given the liveness
/// information in `la`.
static bool hasLive(ValueRange values, const DenseSet<Value> &nonLiveSet,
- RunLivenessAnalysis &la) {
+ const DenseSet<Value> &liveSet, RunLivenessAnalysis &la) {
for (Value value : values) {
+ if (liveSet.contains(value)) {
+ LDBG() << "Value " << value << " is marked live by CallOp";
+ return true;
+ }
+
if (nonLiveSet.contains(value)) {
LDBG() << "Value " << value << " is already marked non-live (dead)";
continue;
@@ -144,6 +150,7 @@ static bool hasLive(ValueRange values, const DenseSet<Value> &nonLiveSet,
/// Return a BitVector of size `values.size()` where its i-th bit is 1 iff the
/// i-th value in `values` is live, given the liveness information in `la`.
static BitVector markLives(ValueRange values, const DenseSet<Value> &nonLiveSet,
+ const DenseSet<Value> &liveSet,
RunLivenessAnalysis &la) {
BitVector lives(values.size(), true);
@@ -154,7 +161,9 @@ static BitVector markLives(ValueRange values, const DenseSet<Value> &nonLiveSet,
<< " is already marked non-live (dead) at index " << index;
continue;
}
-
+ if (liveSet.contains(value)) {
+ continue;
+ }
const Liveness *liveness = la.getLiveness(value);
// It is important to note that when `liveness` is null, we can't tell if
// `value` is live or not. So, the safe option is to consider it live. Also,
@@ -259,8 +268,19 @@ static SmallVector<OpOperand *> operandsToOpOperands(OperandRange operands) {
/// - Return-like
static void processSimpleOp(Operation *op, RunLivenessAnalysis &la,
DenseSet<Value> &nonLiveSet,
- RDVFinalCleanupList &cl) {
- if (!isMemoryEffectFree(op) || hasLive(op->getResults(), nonLiveSet, la)) {
+ DenseSet<Value> &liveSet, RDVFinalCleanupList &cl) {
+ for (Value val : op->getResults()) {
+ if (liveSet.contains(val)) {
+ LDBG() << "Simple op is used by a public function, "
+ "preserving it: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
+ liveSet.insert_range(op->getOperands());
+ return;
+ }
+ }
+
+ if (!isMemoryEffectFree(op) ||
+ hasLive(op->getResults(), nonLiveSet, liveSet, la)) {
LDBG() << "Simple op is not memory effect free or has live results, "
"preserving it: "
<< OpWithFlags(op, OpPrintingFlags().skipRegions());
@@ -288,7 +308,7 @@ static void processSimpleOp(Operation *op, RunLivenessAnalysis &la,
/// (6) Marking all its results as non-live values.
static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
RunLivenessAnalysis &la, DenseSet<Value> &nonLiveSet,
- RDVFinalCleanupList &cl) {
+ DenseSet<Value> &liveSet, RDVFinalCleanupList &cl) {
LDBG() << "Processing function op: "
<< OpWithFlags(funcOp, OpPrintingFlags().skipRegions());
if (funcOp.isPublic() || funcOp.isExternal()) {
@@ -299,7 +319,7 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
// Get the list of unnecessary (non-live) arguments in `nonLiveArgs`.
SmallVector<Value> arguments(funcOp.getArguments());
- BitVector nonLiveArgs = markLives(arguments, nonLiveSet, la);
+ BitVector nonLiveArgs = markLives(arguments, nonLiveSet, liveSet, la);
nonLiveArgs = nonLiveArgs.flip();
// Do (1).
@@ -352,7 +372,8 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
for (SymbolTable::SymbolUse use : uses) {
Operation *callOp = use.getUser();
assert(isa<CallOpInterface>(callOp) && "expected a call-like user");
- BitVector liveCallRets = markLives(callOp->getResults(), nonLiveSet, la);
+ BitVector liveCallRets =
+ markLives(callOp->getResults(), nonLiveSet, liveSet, la);
nonLiveRets &= liveCallRets.flip();
}
@@ -379,6 +400,56 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
}
}
+// Create a cheaper value with the same type of oldVal in front of CallOp.
+static Value createDummyArgument(CallOpInterface callOp, Value oldVal) {
+ OpBuilder builder(callOp.getOperation());
+ Type type = oldVal.getType();
+
+ // Create zero constant for any supported type
+ if (TypedAttr zeroAttr = builder.getZeroAttr(type)) {
+ return builder.create<arith::ConstantOp>(oldVal.getLoc(), type, zeroAttr);
+ }
+ return {};
+}
+
+static void processCallOp(CallOpInterface callOp, Operation *module,
+ RunLivenessAnalysis &la, DenseSet<Value> &nonLiveSet,
+ DenseSet<Value> &liveSet) {
+ if (!la.getSolverConfig().isInterprocedural())
+ return;
+
+ Operation *callableOp = callOp.resolveCallable();
+ auto funcOp = dyn_cast<FunctionOpInterface>(callableOp);
+ if (!funcOp || !funcOp.isPublic()) {
+ return;
+ }
+
+ LDBG() << "processCallOp to a public function: " << funcOp.getName();
+ // Get the list of unnecessary (non-live) arguments in `nonLiveArgs`.
+ SmallVector<Value> arguments(funcOp.getArguments());
+ BitVector nonLiveArgs = markLives(arguments, nonLiveSet, liveSet, la);
+ nonLiveArgs = nonLiveArgs.flip();
+
+ if (nonLiveArgs.count() > 0) {
+ LDBG() << funcOp.getName() << " contains NonLive arguments";
+ // The number of operands in the call op may not match the number of
+ // arguments in the func op.
+ SmallVector<OpOperand *> callOpOperands =
+ operandsToOpOperands(callOp.getArgOperands());
+
+ for (int index : nonLiveArgs.set_bits()) {
+ OpOperand *operand = callOpOperands[index];
+ Value oldVal = operand->get();
+ if (Value dummy = createDummyArgument(callOp, oldVal)) {
+ callOp->setOperand(operand->getOperandNumber(), dummy);
+ nonLiveSet.insert(oldVal);
+ } else {
+ liveSet.insert(oldVal);
+ }
+ }
+ }
+}
+
/// Process a region branch operation `regionBranchOp` using the liveness
/// information in `la`. The processing involves two scenarios:
///
@@ -411,12 +482,14 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
RunLivenessAnalysis &la,
DenseSet<Value> &nonLiveSet,
+ DenseSet<Value> &liveSet,
RDVFinalCleanupList &cl) {
LDBG() << "Processing region branch op: "
<< OpWithFlags(regionBranchOp, OpPrintingFlags().skipRegions());
// Mark live results of `regionBranchOp` in `liveResults`.
auto markLiveResults = [&](BitVector &liveResults) {
- liveResults = markLives(regionBranchOp->getResults(), nonLiveSet, la);
+ liveResults =
+ markLives(regionBranchOp->getResults(), nonLiveSet, liveSet, la);
};
// Mark live arguments in the regions of `regionBranchOp` in `liveArgs`.
@@ -425,7 +498,7 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
if (region.empty())
continue;
SmallVector<Value> arguments(region.front().getArguments());
- BitVector regionLiveArgs = markLives(arguments, nonLiveSet, la);
+ BitVector regionLiveArgs = markLives(arguments, nonLiveSet, liveSet, la);
liveArgs[®ion] = regionLiveArgs;
}
};
@@ -619,7 +692,7 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
// attributed to something else.
// Do (1') and (2').
if (isMemoryEffectFree(regionBranchOp.getOperation()) &&
- !hasLive(regionBranchOp->getResults(), nonLiveSet, la)) {
+ !hasLive(regionBranchOp->getResults(), nonLiveSet, liveSet, la)) {
cl.operations.push_back(regionBranchOp.getOperation());
return;
}
@@ -698,7 +771,7 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la,
DenseSet<Value> &nonLiveSet,
- RDVFinalCleanupList &cl) {
+ DenseSet<Value> &liveSet, RDVFinalCleanupList &cl) {
LDBG() << "Processing branch op: " << *branchOp;
unsigned numSuccessors = branchOp->getNumSuccessors();
@@ -716,7 +789,7 @@ static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la,
// Do (2)
BitVector successorNonLive =
- markLives(operandValues, nonLiveSet, la).flip();
+ markLives(operandValues, nonLiveSet, liveSet, la).flip();
collectNonLiveValues(nonLiveSet, successorBlock->getArguments(),
successorNonLive);
@@ -876,26 +949,29 @@ void RemoveDeadValues::runOnOperation() {
// Tracks values eligible for erasure - complements liveness analysis to
// identify "droppable" values.
DenseSet<Value> deadVals;
+ // mark outgoing arguments to a public function LIVE. We also propagate
+ // liveness backward.
+ DenseSet<Value> liveVals;
// Maintains a list of Ops, values, branches, etc., slated for cleanup at the
// end of this pass.
RDVFinalCleanupList finalCleanupList;
- module->walk([&](Operation *op) {
+ module->walk<WalkOrder::PostOrder, BackwardIterator>([&](Operation *op) {
if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
- processFuncOp(funcOp, module, la, deadVals, finalCleanupList);
+ processFuncOp(funcOp, module, la, deadVals, liveVals, finalCleanupList);
} else if (auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op)) {
- processRegionBranchOp(regionBranchOp, la, deadVals, finalCleanupList);
+ processRegionBranchOp(regionBranchOp, la, deadVals, liveVals,
+ finalCleanupList);
} else if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
- processBranchOp(branchOp, la, deadVals, finalCleanupList);
+ processBranchOp(branchOp, la, deadVals, liveVals, finalCleanupList);
} else if (op->hasTrait<::mlir::OpTrait::IsTerminator>()) {
// Nothing to do here because this is a terminator op and it should be
// honored with respect to its parent
} else if (isa<CallOpInterface>(op)) {
- // Nothing to do because this op is associated with a function op and gets
- // cleaned when the latter is cleaned.
+ processCallOp(cast<CallOpInterface>(op), module, la, deadVals, liveVals);
} else {
- processSimpleOp(op, la, deadVals, finalCleanupList);
+ processSimpleOp(op, la, deadVals, liveVals, finalCleanupList);
}
});
diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir
index 56449469dc29f..ebb76a8835ceb 100644
--- a/mlir/test/Transforms/remove-dead-values.mlir
+++ b/mlir/test/Transforms/remove-dead-values.mlir
@@ -569,6 +569,24 @@ module @return_void_with_unused_argument {
call @fn_return_void_with_unused_argument(%arg0, %unused) : (i32, memref<4xi32>) -> ()
return %unused : memref<4xi32>
}
+
+ // the function signature is immutable because it is public.
+ func.func public @immutable_fn_with_unused_argument(%arg0: i32, %arg1: memref<4xf32>) -> () {
+ return
+ }
+
+ // CHECK-LABEL: func.func @main2
+ // CHECK: %[[MEM:.*]] = memref.alloc() : memref<4xf32>
+ // CHECK: %[[UNUSED:.*]] = arith.constant 0 : i32
+ // CHECK: call @immutable_fn_with_unused_argument(%[[UNUSED]], %[[MEM]]) : (i32, memref<4xf32>) -> ()
+ func.func @main2() -> () {
+ %one = arith.constant 1 : i32
+ %scalar = arith.addi %one, %one: i32
+ %mem = memref.alloc() : memref<4xf32>
+
+ call @immutable_fn_with_unused_argument(%scalar, %mem) : (i32, memref<4xf32>) -> ()
+ return
+ }
}
// -----
|
@llvm/pr-subscribers-mlir-core Author: xin liu (navyxliu) ChangesThis diff is REDO of #160242. ProblemLiveness analysis is inter-procedural. If there are some unused arguments in a public function, they propagate to callers. From the perspective of RemoveDeadValues, the signature of a public function is immutable. It can't cope with this situation. One side, it deletes outgoing arguments, on the other side it keeps the function intact. SolutionWe deploy two methods to fix this bug.
Test plan
Full diff: https://github.com/llvm/llvm-project/pull/162038.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h
index cf1fd6e2d48ca..be7e027b95f64 100644
--- a/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h
@@ -102,6 +102,9 @@ struct RunLivenessAnalysis {
const Liveness *getLiveness(Value val);
+ /// Return the configuration of the solver used for this analysis.
+ const DataFlowConfig &getSolverConfig() const { return solver.getConfig(); }
+
private:
/// Stores the result of the liveness analysis that was run.
DataFlowSolver solver;
diff --git a/mlir/include/mlir/IR/Visitors.h b/mlir/include/mlir/IR/Visitors.h
index 893f66ae33deb..5766d262796d6 100644
--- a/mlir/include/mlir/IR/Visitors.h
+++ b/mlir/include/mlir/IR/Visitors.h
@@ -39,6 +39,20 @@ struct ForwardIterator {
}
};
+/// This iterator enumerates the elements in "backward" order.
+struct BackwardIterator {
+ template <typename T>
+ static auto makeIterable(T &range) {
+ if constexpr (std::is_same<T, Operation>()) {
+ /// Make operations iterable: return the list of regions.
+ return llvm::reverse(range.getRegions());
+ } else {
+ /// Regions and block are already iterable.
+ return llvm::reverse(range);
+ }
+ }
+};
+
/// A utility class to encode the current walk stage for "generic" walkers.
/// When walking an operation, we can either choose a Pre/Post order walker
/// which invokes the callback on an operation before/after all its attached
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index e0c65b0e09774..ed5c6a8d2ead0 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -33,6 +33,7 @@
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlow/LivenessAnalysis.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Dialect.h"
@@ -118,8 +119,13 @@ struct RDVFinalCleanupList {
/// Return true iff at least one value in `values` is live, given the liveness
/// information in `la`.
static bool hasLive(ValueRange values, const DenseSet<Value> &nonLiveSet,
- RunLivenessAnalysis &la) {
+ const DenseSet<Value> &liveSet, RunLivenessAnalysis &la) {
for (Value value : values) {
+ if (liveSet.contains(value)) {
+ LDBG() << "Value " << value << " is marked live by CallOp";
+ return true;
+ }
+
if (nonLiveSet.contains(value)) {
LDBG() << "Value " << value << " is already marked non-live (dead)";
continue;
@@ -144,6 +150,7 @@ static bool hasLive(ValueRange values, const DenseSet<Value> &nonLiveSet,
/// Return a BitVector of size `values.size()` where its i-th bit is 1 iff the
/// i-th value in `values` is live, given the liveness information in `la`.
static BitVector markLives(ValueRange values, const DenseSet<Value> &nonLiveSet,
+ const DenseSet<Value> &liveSet,
RunLivenessAnalysis &la) {
BitVector lives(values.size(), true);
@@ -154,7 +161,9 @@ static BitVector markLives(ValueRange values, const DenseSet<Value> &nonLiveSet,
<< " is already marked non-live (dead) at index " << index;
continue;
}
-
+ if (liveSet.contains(value)) {
+ continue;
+ }
const Liveness *liveness = la.getLiveness(value);
// It is important to note that when `liveness` is null, we can't tell if
// `value` is live or not. So, the safe option is to consider it live. Also,
@@ -259,8 +268,19 @@ static SmallVector<OpOperand *> operandsToOpOperands(OperandRange operands) {
/// - Return-like
static void processSimpleOp(Operation *op, RunLivenessAnalysis &la,
DenseSet<Value> &nonLiveSet,
- RDVFinalCleanupList &cl) {
- if (!isMemoryEffectFree(op) || hasLive(op->getResults(), nonLiveSet, la)) {
+ DenseSet<Value> &liveSet, RDVFinalCleanupList &cl) {
+ for (Value val : op->getResults()) {
+ if (liveSet.contains(val)) {
+ LDBG() << "Simple op is used by a public function, "
+ "preserving it: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
+ liveSet.insert_range(op->getOperands());
+ return;
+ }
+ }
+
+ if (!isMemoryEffectFree(op) ||
+ hasLive(op->getResults(), nonLiveSet, liveSet, la)) {
LDBG() << "Simple op is not memory effect free or has live results, "
"preserving it: "
<< OpWithFlags(op, OpPrintingFlags().skipRegions());
@@ -288,7 +308,7 @@ static void processSimpleOp(Operation *op, RunLivenessAnalysis &la,
/// (6) Marking all its results as non-live values.
static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
RunLivenessAnalysis &la, DenseSet<Value> &nonLiveSet,
- RDVFinalCleanupList &cl) {
+ DenseSet<Value> &liveSet, RDVFinalCleanupList &cl) {
LDBG() << "Processing function op: "
<< OpWithFlags(funcOp, OpPrintingFlags().skipRegions());
if (funcOp.isPublic() || funcOp.isExternal()) {
@@ -299,7 +319,7 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
// Get the list of unnecessary (non-live) arguments in `nonLiveArgs`.
SmallVector<Value> arguments(funcOp.getArguments());
- BitVector nonLiveArgs = markLives(arguments, nonLiveSet, la);
+ BitVector nonLiveArgs = markLives(arguments, nonLiveSet, liveSet, la);
nonLiveArgs = nonLiveArgs.flip();
// Do (1).
@@ -352,7 +372,8 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
for (SymbolTable::SymbolUse use : uses) {
Operation *callOp = use.getUser();
assert(isa<CallOpInterface>(callOp) && "expected a call-like user");
- BitVector liveCallRets = markLives(callOp->getResults(), nonLiveSet, la);
+ BitVector liveCallRets =
+ markLives(callOp->getResults(), nonLiveSet, liveSet, la);
nonLiveRets &= liveCallRets.flip();
}
@@ -379,6 +400,56 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
}
}
+// Create a cheaper value with the same type of oldVal in front of CallOp.
+static Value createDummyArgument(CallOpInterface callOp, Value oldVal) {
+ OpBuilder builder(callOp.getOperation());
+ Type type = oldVal.getType();
+
+ // Create zero constant for any supported type
+ if (TypedAttr zeroAttr = builder.getZeroAttr(type)) {
+ return builder.create<arith::ConstantOp>(oldVal.getLoc(), type, zeroAttr);
+ }
+ return {};
+}
+
+static void processCallOp(CallOpInterface callOp, Operation *module,
+ RunLivenessAnalysis &la, DenseSet<Value> &nonLiveSet,
+ DenseSet<Value> &liveSet) {
+ if (!la.getSolverConfig().isInterprocedural())
+ return;
+
+ Operation *callableOp = callOp.resolveCallable();
+ auto funcOp = dyn_cast<FunctionOpInterface>(callableOp);
+ if (!funcOp || !funcOp.isPublic()) {
+ return;
+ }
+
+ LDBG() << "processCallOp to a public function: " << funcOp.getName();
+ // Get the list of unnecessary (non-live) arguments in `nonLiveArgs`.
+ SmallVector<Value> arguments(funcOp.getArguments());
+ BitVector nonLiveArgs = markLives(arguments, nonLiveSet, liveSet, la);
+ nonLiveArgs = nonLiveArgs.flip();
+
+ if (nonLiveArgs.count() > 0) {
+ LDBG() << funcOp.getName() << " contains NonLive arguments";
+ // The number of operands in the call op may not match the number of
+ // arguments in the func op.
+ SmallVector<OpOperand *> callOpOperands =
+ operandsToOpOperands(callOp.getArgOperands());
+
+ for (int index : nonLiveArgs.set_bits()) {
+ OpOperand *operand = callOpOperands[index];
+ Value oldVal = operand->get();
+ if (Value dummy = createDummyArgument(callOp, oldVal)) {
+ callOp->setOperand(operand->getOperandNumber(), dummy);
+ nonLiveSet.insert(oldVal);
+ } else {
+ liveSet.insert(oldVal);
+ }
+ }
+ }
+}
+
/// Process a region branch operation `regionBranchOp` using the liveness
/// information in `la`. The processing involves two scenarios:
///
@@ -411,12 +482,14 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
RunLivenessAnalysis &la,
DenseSet<Value> &nonLiveSet,
+ DenseSet<Value> &liveSet,
RDVFinalCleanupList &cl) {
LDBG() << "Processing region branch op: "
<< OpWithFlags(regionBranchOp, OpPrintingFlags().skipRegions());
// Mark live results of `regionBranchOp` in `liveResults`.
auto markLiveResults = [&](BitVector &liveResults) {
- liveResults = markLives(regionBranchOp->getResults(), nonLiveSet, la);
+ liveResults =
+ markLives(regionBranchOp->getResults(), nonLiveSet, liveSet, la);
};
// Mark live arguments in the regions of `regionBranchOp` in `liveArgs`.
@@ -425,7 +498,7 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
if (region.empty())
continue;
SmallVector<Value> arguments(region.front().getArguments());
- BitVector regionLiveArgs = markLives(arguments, nonLiveSet, la);
+ BitVector regionLiveArgs = markLives(arguments, nonLiveSet, liveSet, la);
liveArgs[®ion] = regionLiveArgs;
}
};
@@ -619,7 +692,7 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
// attributed to something else.
// Do (1') and (2').
if (isMemoryEffectFree(regionBranchOp.getOperation()) &&
- !hasLive(regionBranchOp->getResults(), nonLiveSet, la)) {
+ !hasLive(regionBranchOp->getResults(), nonLiveSet, liveSet, la)) {
cl.operations.push_back(regionBranchOp.getOperation());
return;
}
@@ -698,7 +771,7 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la,
DenseSet<Value> &nonLiveSet,
- RDVFinalCleanupList &cl) {
+ DenseSet<Value> &liveSet, RDVFinalCleanupList &cl) {
LDBG() << "Processing branch op: " << *branchOp;
unsigned numSuccessors = branchOp->getNumSuccessors();
@@ -716,7 +789,7 @@ static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la,
// Do (2)
BitVector successorNonLive =
- markLives(operandValues, nonLiveSet, la).flip();
+ markLives(operandValues, nonLiveSet, liveSet, la).flip();
collectNonLiveValues(nonLiveSet, successorBlock->getArguments(),
successorNonLive);
@@ -876,26 +949,29 @@ void RemoveDeadValues::runOnOperation() {
// Tracks values eligible for erasure - complements liveness analysis to
// identify "droppable" values.
DenseSet<Value> deadVals;
+ // mark outgoing arguments to a public function LIVE. We also propagate
+ // liveness backward.
+ DenseSet<Value> liveVals;
// Maintains a list of Ops, values, branches, etc., slated for cleanup at the
// end of this pass.
RDVFinalCleanupList finalCleanupList;
- module->walk([&](Operation *op) {
+ module->walk<WalkOrder::PostOrder, BackwardIterator>([&](Operation *op) {
if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
- processFuncOp(funcOp, module, la, deadVals, finalCleanupList);
+ processFuncOp(funcOp, module, la, deadVals, liveVals, finalCleanupList);
} else if (auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op)) {
- processRegionBranchOp(regionBranchOp, la, deadVals, finalCleanupList);
+ processRegionBranchOp(regionBranchOp, la, deadVals, liveVals,
+ finalCleanupList);
} else if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
- processBranchOp(branchOp, la, deadVals, finalCleanupList);
+ processBranchOp(branchOp, la, deadVals, liveVals, finalCleanupList);
} else if (op->hasTrait<::mlir::OpTrait::IsTerminator>()) {
// Nothing to do here because this is a terminator op and it should be
// honored with respect to its parent
} else if (isa<CallOpInterface>(op)) {
- // Nothing to do because this op is associated with a function op and gets
- // cleaned when the latter is cleaned.
+ processCallOp(cast<CallOpInterface>(op), module, la, deadVals, liveVals);
} else {
- processSimpleOp(op, la, deadVals, finalCleanupList);
+ processSimpleOp(op, la, deadVals, liveVals, finalCleanupList);
}
});
diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir
index 56449469dc29f..ebb76a8835ceb 100644
--- a/mlir/test/Transforms/remove-dead-values.mlir
+++ b/mlir/test/Transforms/remove-dead-values.mlir
@@ -569,6 +569,24 @@ module @return_void_with_unused_argument {
call @fn_return_void_with_unused_argument(%arg0, %unused) : (i32, memref<4xi32>) -> ()
return %unused : memref<4xi32>
}
+
+ // the function signature is immutable because it is public.
+ func.func public @immutable_fn_with_unused_argument(%arg0: i32, %arg1: memref<4xf32>) -> () {
+ return
+ }
+
+ // CHECK-LABEL: func.func @main2
+ // CHECK: %[[MEM:.*]] = memref.alloc() : memref<4xf32>
+ // CHECK: %[[UNUSED:.*]] = arith.constant 0 : i32
+ // CHECK: call @immutable_fn_with_unused_argument(%[[UNUSED]], %[[MEM]]) : (i32, memref<4xf32>) -> ()
+ func.func @main2() -> () {
+ %one = arith.constant 1 : i32
+ %scalar = arith.addi %one, %one: i32
+ %mem = memref.alloc() : memref<4xf32>
+
+ call @immutable_fn_with_unused_argument(%scalar, %mem) : (i32, memref<4xf32>) -> ()
+ return
+ }
}
// -----
|
hi, @joker-eph |
call @immutable_fn_with_unused_argument(%scalar, %mem) : (i32, memref<4xf32>) -> () | ||
return | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add two CFG examples where the blocks are listed in different order to ensure you're not sensitive to the order the blocks are in-memory.
RunLivenessAnalysis &la, DenseSet<Value> &nonLiveSet, | ||
DenseSet<Value> &liveSet) { | ||
if (!la.getSolverConfig().isInterprocedural()) | ||
return; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That deserves a comment.
return {}; | ||
} | ||
|
||
static void processCallOp(CallOpInterface callOp, Operation *module, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you document the API?
if (!funcOp || !funcOp.isPublic()) { | ||
return; | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if (!funcOp || !funcOp.isPublic()) { | |
return; | |
} | |
if (!funcOp || !funcOp.isPublic()) | |
return; |
Nit: no trivial braces.
return llvm::reverse(range); | ||
} | ||
} | ||
}; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Which do you need to do this? Are you somehow trying to make an assumption about the order in which the regions are executed at runtime?
The order of the region on the op are not indicative of anything related to this.
continue; | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
continue; | |
} | |
LDBG() << "Value " << value | |
<< " is already marked live at index " << index; | |
continue; | |
} |
Should we align the logging with above?
(otherwise remove the braces)
return builder.create<arith::ConstantOp>(oldVal.getLoc(), type, zeroAttr); | ||
} | ||
return {}; | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we split this out in a follow-up PR: keep this PR about fixing the bugs without introducing an "aggressive" optimization, and introduce the optimization on its own afterward?
This diff is REDO of #160242.
Integrated feedbacks from the previous feedbacks and also try to substitute arguments with dummy values first.
Problem
Liveness analysis is inter-procedural. If there are some unused arguments in a public function, they propagate to callers. From the perspective of RemoveDeadValues, the signature of a public function is immutable. It can't cope with this situation. One side, it deletes outgoing arguments, on the other side it keeps the function intact.
Solution
We deploy two methods to fix this bug.
createDummyArgument replaces arguments with dummies. eg. If the type can create a value with zero initializer, such as 'T x{0}'.
As a fallback, we add another DenseSet called 'liveSet'. The initial values are the arguments of public functions. It propagates liveness backward just like Liveness analysis.
Test plan