Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion python/tvm/relay/scope_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""The scope builder interface """

"""The scope builder interface."""
from __future__ import absolute_import

from . import expr as _expr
Expand Down
2 changes: 1 addition & 1 deletion src/relay/ir/alpha_equal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -419,8 +419,8 @@ class AlphaEqualHandler:

bool VisitExpr_(const LetNode* lhs, const Expr& other) final {
if (const LetNode* rhs = other.as<LetNode>()) {
if (!ExprEqual(lhs->value, rhs->value)) return false;
if (!MergeVarDecl(lhs->var, rhs->var)) return false;
if (!ExprEqual(lhs->value, rhs->value)) return false;
return ExprEqual(lhs->body, rhs->body);
} else {
return false;
Expand Down
163 changes: 68 additions & 95 deletions src/relay/pass/dead_code.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,121 +36,94 @@
namespace tvm {
namespace relay {

template<typename X>
using VarMap = std::unordered_map<Var, X, NodeHash, NodeEqual>;
using VarSet = std::unordered_set<Var, NodeHash, NodeEqual>;

class CalcDep;
class FindDef : private ExprVisitor {
private:
VarMap<Expr> expr_map_;

void VisitExpr_(const LetNode* l) final {
CHECK_EQ(expr_map_.count(l->var), 0);
expr_map_[l->var] = l->value;
VisitExpr(l->value);
VisitExpr(l->body);
}

friend CalcDep;
};

class Eliminator : private ExprMutator {
private:
VarMap<Expr> expr_map_;
VarMap<size_t> use_map_;
bool inline_once_;
explicit Eliminator(const VarMap<Expr>& expr_map,
const VarMap<size_t>& use_map,
bool inline_once) :
expr_map_(expr_map), use_map_(use_map), inline_once_(inline_once) { }
friend CalcDep;

bool HasLet(const Var& v) {
switch (use_map_[v]) {
case 0:
return false;
case 1:
return !inline_once_;
default:
return true;
}
}

Expr VisitExpr_(const VarNode* op) final {
Var v = GetRef<Var>(op);
return (expr_map_.count(v) == 0 || HasLet(v)) ? v : VisitExpr(expr_map_[v]);
}

Expr VisitExpr_(const LetNode* op) final {
Var v = op->var;
if (HasLet(v)) {
return LetNode::make(v, VisitExpr(op->value), VisitExpr(op->body));
} else {
return VisitExpr(op->body);
}
}
};

// calculate the dependency graph from expression
class CalcDep : private ExprVisitor {
public:
static Expr Eliminate(const Expr& e, bool inline_once) {
CalcDep cd;
cd.Calculate(e);
Eliminator el(cd.expr_map_, cd.use_map_, cd.letrec_set_, inline_once);
FindDef fd;
fd(e);
CalcDep cd(fd.expr_map_);
cd(e);
Eliminator el(fd.expr_map_, cd.use_map_, inline_once);
return el(e);
}

private:
template<typename X>
using VarMap = std::unordered_map<Var, X, NodeHash, NodeEqual>;
using VarSet = std::unordered_set<Var, NodeHash, NodeEqual>;
explicit CalcDep(const VarMap<Expr>& expr_map) : expr_map_(expr_map) { }
VarMap<Expr> expr_map_;
VarMap<size_t> use_map_;
VarSet letrec_set_;
bool count_ = true;
VarSet dead_worklist_;
VarSet current_letrec_;

void LetRec(const std::function<void()>& func, const Var& v) {
current_letrec_.insert(v);
func();
current_letrec_.erase(v);

void VisitExpr(const Expr& e) final {
return ExprFunctor<void(const Expr& e)>::VisitExpr(e);
}

void VisitExpr_(const LetNode* l) final {
if (count_) {
CHECK_EQ(expr_map_.count(l->var), 0);
CHECK_EQ(use_map_.count(l->var), 0);
expr_map_[l->var] = l->value;
use_map_[l->var] = 0;
dead_worklist_.insert(l->var);
LetRec([&]() { VisitExpr(l->value); }, l->var);
}
VisitExpr(l->body);
}

void VisitExpr(const Expr& e) final {
ExprFunctor<void(const Expr&)>::VisitExpr(e);
}

void VisitExpr_(const VarNode* v) final {
Var var = GetRef<Var>(v);
if (expr_map_.count(var) == 0) {
return;
}
if (current_letrec_.count(var) == 0) {
if (count_) {
use_map_[var] += 1;
dead_worklist_.erase(var);
} else {
CHECK_GT(use_map_[var], 0) << var;
use_map_[var] -= 1;
if (use_map_[var] == 0) {
dead_worklist_.insert(var);
}
}
} else {
letrec_set_.insert(var);
++use_map_[var];
if (use_map_[var] == 1 && expr_map_.count(var) > 0) {
VisitExpr(expr_map_[var]);
}
}

void Calculate(const Expr& v) {
VisitExpr(v);
count_ = false;
while (!dead_worklist_.empty()) {
Var dead = *(dead_worklist_.begin());
dead_worklist_.erase(dead);
CHECK_EQ(use_map_[dead], 0);
if (expr_map_.count(dead) > 0) {
LetRec([&]() { VisitExpr(expr_map_[dead]); }, dead);
}
}
}

class Eliminator : private ExprMutator {
private:
VarMap<Expr> expr_map_;
VarMap<size_t> use_map_;
VarSet letrec_set_;
bool inline_once_;
explicit Eliminator(const VarMap<Expr>& expr_map,
const VarMap<size_t>& use_map,
const VarSet& letrec_set,
bool inline_once) :
expr_map_(expr_map), use_map_(use_map), letrec_set_(letrec_set), inline_once_(inline_once) { }
friend CalcDep;

bool HasLet(const Var& v) {
switch (use_map_[v]) {
case 0:
return false;
case 1:
return letrec_set_.count(v) > 0 || !inline_once_;
default:
return true;
}
}

Expr VisitExpr_(const VarNode* op) final {
Var v = GetRef<Var>(op);
return (expr_map_.count(v) == 0 || HasLet(v)) ? v : VisitExpr(expr_map_[v]);
}

Expr VisitExpr_(const LetNode* op) final {
Var v = op->var;
if (HasLet(v)) {
return LetNode::make(v, VisitExpr(op->value), VisitExpr(op->body));
} else {
return VisitExpr(op->body);
}
}
};
};

Expr DeadCodeElimination(const Expr& e, bool inline_once) {
Expand Down
Loading