Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,9 @@ struct RunLivenessAnalysis {

const Liveness *getLiveness(Value val);

/// Return the configuration of the solver used for this analysis.
const DataFlowConfig &getSolverConfig() const { return solver.getConfig(); }

private:
/// Stores the result of the liveness analysis that was run.
DataFlowSolver solver;
Expand Down
14 changes: 14 additions & 0 deletions mlir/include/mlir/IR/Visitors.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,20 @@ struct ForwardIterator {
}
};

/// This iterator enumerates the elements in "backward" order.
struct BackwardIterator {
template <typename T>
static auto makeIterable(T &range) {
if constexpr (std::is_same<T, Operation>()) {
/// Make operations iterable: return the list of regions.
return llvm::reverse(range.getRegions());
} else {
/// Regions and block are already iterable.
return llvm::reverse(range);
}
}
};
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Which do you need to do this? Are you somehow trying to make an assumption about the order in which the regions are executed at runtime?
The order of the region on the op are not indicative of anything related to this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hi, @joker-eph,

Which do you need to do this?
I change iterator here. My intention is to visit basic block and CFG in reverse order.
module->walk<WalkOrder::PostOrder, BackwardIterator>

The order of the region on the op are not indicative of anything related to this.
TBH, this is the part of MLIR I don't understand. Should we also reverse Operation::getRegions()? yes, my understanding is that the order doesn't matter for regions.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My intention is to visit basic block and CFG in reverse order.

Reverse order from what? The order of basic blocks isn't indicative of any execution order. The only constraint is on the first basic block to be the "entry" into the region.


/// A utility class to encode the current walk stage for "generic" walkers.
/// When walking an operation, we can either choose a Pre/Post order walker
/// which invokes the callback on an operation before/after all its attached
Expand Down
114 changes: 95 additions & 19 deletions mlir/lib/Transforms/RemoveDeadValues.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlow/LivenessAnalysis.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Dialect.h"
Expand Down Expand Up @@ -118,8 +119,13 @@ struct RDVFinalCleanupList {
/// Return true iff at least one value in `values` is live, given the liveness
/// information in `la`.
static bool hasLive(ValueRange values, const DenseSet<Value> &nonLiveSet,
RunLivenessAnalysis &la) {
const DenseSet<Value> &liveSet, RunLivenessAnalysis &la) {
for (Value value : values) {
if (liveSet.contains(value)) {
LDBG() << "Value " << value << " is marked live by CallOp";
return true;
}

if (nonLiveSet.contains(value)) {
LDBG() << "Value " << value << " is already marked non-live (dead)";
continue;
Expand All @@ -144,6 +150,7 @@ static bool hasLive(ValueRange values, const DenseSet<Value> &nonLiveSet,
/// Return a BitVector of size `values.size()` where its i-th bit is 1 iff the
/// i-th value in `values` is live, given the liveness information in `la`.
static BitVector markLives(ValueRange values, const DenseSet<Value> &nonLiveSet,
const DenseSet<Value> &liveSet,
RunLivenessAnalysis &la) {
BitVector lives(values.size(), true);

Expand All @@ -154,7 +161,9 @@ static BitVector markLives(ValueRange values, const DenseSet<Value> &nonLiveSet,
<< " is already marked non-live (dead) at index " << index;
continue;
}

if (liveSet.contains(value)) {
continue;
}
Comment on lines +165 to +166
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
continue;
}
LDBG() << "Value " << value
<< " is already marked live at index " << index;
continue;
}

Should we align the logging with above?
(otherwise remove the braces)

const Liveness *liveness = la.getLiveness(value);
// It is important to note that when `liveness` is null, we can't tell if
// `value` is live or not. So, the safe option is to consider it live. Also,
Expand Down Expand Up @@ -259,8 +268,19 @@ static SmallVector<OpOperand *> operandsToOpOperands(OperandRange operands) {
/// - Return-like
static void processSimpleOp(Operation *op, RunLivenessAnalysis &la,
DenseSet<Value> &nonLiveSet,
RDVFinalCleanupList &cl) {
if (!isMemoryEffectFree(op) || hasLive(op->getResults(), nonLiveSet, la)) {
DenseSet<Value> &liveSet, RDVFinalCleanupList &cl) {
for (Value val : op->getResults()) {
if (liveSet.contains(val)) {
LDBG() << "Simple op is used by a public function, "
"preserving it: "
<< OpWithFlags(op, OpPrintingFlags().skipRegions());
liveSet.insert_range(op->getOperands());
return;
}
}

if (!isMemoryEffectFree(op) ||
hasLive(op->getResults(), nonLiveSet, liveSet, la)) {
LDBG() << "Simple op is not memory effect free or has live results, "
"preserving it: "
<< OpWithFlags(op, OpPrintingFlags().skipRegions());
Expand Down Expand Up @@ -288,7 +308,7 @@ static void processSimpleOp(Operation *op, RunLivenessAnalysis &la,
/// (6) Marking all its results as non-live values.
static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
RunLivenessAnalysis &la, DenseSet<Value> &nonLiveSet,
RDVFinalCleanupList &cl) {
DenseSet<Value> &liveSet, RDVFinalCleanupList &cl) {
LDBG() << "Processing function op: "
<< OpWithFlags(funcOp, OpPrintingFlags().skipRegions());
if (funcOp.isPublic() || funcOp.isExternal()) {
Expand All @@ -299,7 +319,7 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,

// Get the list of unnecessary (non-live) arguments in `nonLiveArgs`.
SmallVector<Value> arguments(funcOp.getArguments());
BitVector nonLiveArgs = markLives(arguments, nonLiveSet, la);
BitVector nonLiveArgs = markLives(arguments, nonLiveSet, liveSet, la);
nonLiveArgs = nonLiveArgs.flip();

// Do (1).
Expand Down Expand Up @@ -352,7 +372,8 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
for (SymbolTable::SymbolUse use : uses) {
Operation *callOp = use.getUser();
assert(isa<CallOpInterface>(callOp) && "expected a call-like user");
BitVector liveCallRets = markLives(callOp->getResults(), nonLiveSet, la);
BitVector liveCallRets =
markLives(callOp->getResults(), nonLiveSet, liveSet, la);
nonLiveRets &= liveCallRets.flip();
}

Expand All @@ -379,6 +400,56 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
}
}

// Create a cheaper value with the same type of oldVal in front of CallOp.
static Value createDummyArgument(CallOpInterface callOp, Value oldVal) {
OpBuilder builder(callOp.getOperation());
Type type = oldVal.getType();

// Create zero constant for any supported type
if (TypedAttr zeroAttr = builder.getZeroAttr(type)) {
return builder.create<arith::ConstantOp>(oldVal.getLoc(), type, zeroAttr);
}
return {};
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we split this out in a follow-up PR: keep this PR about fixing the bugs without introducing an "aggressive" optimization, and introduce the optimization on its own afterward?


static void processCallOp(CallOpInterface callOp, Operation *module,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you document the API?

RunLivenessAnalysis &la, DenseSet<Value> &nonLiveSet,
DenseSet<Value> &liveSet) {
if (!la.getSolverConfig().isInterprocedural())
return;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That deserves a comment.


Operation *callableOp = callOp.resolveCallable();
auto funcOp = dyn_cast<FunctionOpInterface>(callableOp);
if (!funcOp || !funcOp.isPublic()) {
return;
}
Comment on lines +423 to +425
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (!funcOp || !funcOp.isPublic()) {
return;
}
if (!funcOp || !funcOp.isPublic())
return;

Nit: no trivial braces.


LDBG() << "processCallOp to a public function: " << funcOp.getName();
// Get the list of unnecessary (non-live) arguments in `nonLiveArgs`.
SmallVector<Value> arguments(funcOp.getArguments());
BitVector nonLiveArgs = markLives(arguments, nonLiveSet, liveSet, la);
nonLiveArgs = nonLiveArgs.flip();

if (nonLiveArgs.count() > 0) {
LDBG() << funcOp.getName() << " contains NonLive arguments";
// The number of operands in the call op may not match the number of
// arguments in the func op.
SmallVector<OpOperand *> callOpOperands =
operandsToOpOperands(callOp.getArgOperands());

for (int index : nonLiveArgs.set_bits()) {
OpOperand *operand = callOpOperands[index];
Value oldVal = operand->get();
if (Value dummy = createDummyArgument(callOp, oldVal)) {
callOp->setOperand(operand->getOperandNumber(), dummy);
nonLiveSet.insert(oldVal);
} else {
liveSet.insert(oldVal);
}
}
}
}

/// Process a region branch operation `regionBranchOp` using the liveness
/// information in `la`. The processing involves two scenarios:
///
Expand Down Expand Up @@ -411,12 +482,14 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
RunLivenessAnalysis &la,
DenseSet<Value> &nonLiveSet,
DenseSet<Value> &liveSet,
RDVFinalCleanupList &cl) {
LDBG() << "Processing region branch op: "
<< OpWithFlags(regionBranchOp, OpPrintingFlags().skipRegions());
// Mark live results of `regionBranchOp` in `liveResults`.
auto markLiveResults = [&](BitVector &liveResults) {
liveResults = markLives(regionBranchOp->getResults(), nonLiveSet, la);
liveResults =
markLives(regionBranchOp->getResults(), nonLiveSet, liveSet, la);
};

// Mark live arguments in the regions of `regionBranchOp` in `liveArgs`.
Expand All @@ -425,7 +498,7 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
if (region.empty())
continue;
SmallVector<Value> arguments(region.front().getArguments());
BitVector regionLiveArgs = markLives(arguments, nonLiveSet, la);
BitVector regionLiveArgs = markLives(arguments, nonLiveSet, liveSet, la);
liveArgs[&region] = regionLiveArgs;
}
};
Expand Down Expand Up @@ -619,7 +692,7 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
// attributed to something else.
// Do (1') and (2').
if (isMemoryEffectFree(regionBranchOp.getOperation()) &&
!hasLive(regionBranchOp->getResults(), nonLiveSet, la)) {
!hasLive(regionBranchOp->getResults(), nonLiveSet, liveSet, la)) {
cl.operations.push_back(regionBranchOp.getOperation());
return;
}
Expand Down Expand Up @@ -698,7 +771,7 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,

static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la,
DenseSet<Value> &nonLiveSet,
RDVFinalCleanupList &cl) {
DenseSet<Value> &liveSet, RDVFinalCleanupList &cl) {
LDBG() << "Processing branch op: " << *branchOp;
unsigned numSuccessors = branchOp->getNumSuccessors();

Expand All @@ -716,7 +789,7 @@ static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la,

// Do (2)
BitVector successorNonLive =
markLives(operandValues, nonLiveSet, la).flip();
markLives(operandValues, nonLiveSet, liveSet, la).flip();
collectNonLiveValues(nonLiveSet, successorBlock->getArguments(),
successorNonLive);

Expand Down Expand Up @@ -876,26 +949,29 @@ void RemoveDeadValues::runOnOperation() {
// Tracks values eligible for erasure - complements liveness analysis to
// identify "droppable" values.
DenseSet<Value> deadVals;
// mark outgoing arguments to a public function LIVE. We also propagate
// liveness backward.
DenseSet<Value> liveVals;

// Maintains a list of Ops, values, branches, etc., slated for cleanup at the
// end of this pass.
RDVFinalCleanupList finalCleanupList;

module->walk([&](Operation *op) {
module->walk<WalkOrder::PostOrder, BackwardIterator>([&](Operation *op) {
if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
processFuncOp(funcOp, module, la, deadVals, finalCleanupList);
processFuncOp(funcOp, module, la, deadVals, liveVals, finalCleanupList);
} else if (auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op)) {
processRegionBranchOp(regionBranchOp, la, deadVals, finalCleanupList);
processRegionBranchOp(regionBranchOp, la, deadVals, liveVals,
finalCleanupList);
} else if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
processBranchOp(branchOp, la, deadVals, finalCleanupList);
processBranchOp(branchOp, la, deadVals, liveVals, finalCleanupList);
} else if (op->hasTrait<::mlir::OpTrait::IsTerminator>()) {
// Nothing to do here because this is a terminator op and it should be
// honored with respect to its parent
} else if (isa<CallOpInterface>(op)) {
// Nothing to do because this op is associated with a function op and gets
// cleaned when the latter is cleaned.
processCallOp(cast<CallOpInterface>(op), module, la, deadVals, liveVals);
} else {
processSimpleOp(op, la, deadVals, finalCleanupList);
processSimpleOp(op, la, deadVals, liveVals, finalCleanupList);
}
});

Expand Down
18 changes: 18 additions & 0 deletions mlir/test/Transforms/remove-dead-values.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,24 @@ module @return_void_with_unused_argument {
call @fn_return_void_with_unused_argument(%arg0, %unused) : (i32, memref<4xi32>) -> ()
return %unused : memref<4xi32>
}

// the function signature is immutable because it is public.
func.func public @immutable_fn_with_unused_argument(%arg0: i32, %arg1: memref<4xf32>) -> () {
return
}

// CHECK-LABEL: func.func @main2
// CHECK: %[[MEM:.*]] = memref.alloc() : memref<4xf32>
// CHECK: %[[UNUSED:.*]] = arith.constant 0 : i32
// CHECK: call @immutable_fn_with_unused_argument(%[[UNUSED]], %[[MEM]]) : (i32, memref<4xf32>) -> ()
func.func @main2() -> () {
%one = arith.constant 1 : i32
%scalar = arith.addi %one, %one: i32
%mem = memref.alloc() : memref<4xf32>

call @immutable_fn_with_unused_argument(%scalar, %mem) : (i32, memref<4xf32>) -> ()
return
}
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add two CFG examples where the blocks are listed in different order to ensure you're not sensitive to the order the blocks are in-memory.


// -----
Expand Down