From 110dc84193c43c285a6f07883cf40a94e1a3f0fc Mon Sep 17 00:00:00 2001 From: Marisa Kirisame Date: Wed, 3 Jul 2019 03:07:30 -0700 Subject: [PATCH 1/6] save lint some lint lint add charrnn save save save remove debug remove debug remove space refactor save rewrite dce --- python/tvm/relay/grammar/py3/RelayLexer.py | 4 +- python/tvm/relay/grammar/py3/RelayParser.py | 26 +- python/tvm/relay/grammar/py3/RelayVisitor.py | 2 +- python/tvm/relay/scope_builder.py | 3 +- src/relay/ir/alpha_equal.cc | 2 +- src/relay/pass/dead_code.cc | 163 ++++------ src/relay/pass/partial_eval.cc | 300 +++++++++++++----- src/relay/pass/to_graph_normal_form.cc | 2 +- .../relay/test_pass_dead_code_elimination.py | 39 ++- tests/python/relay/test_pass_partial_eval.py | 14 +- 10 files changed, 361 insertions(+), 194 deletions(-) diff --git a/python/tvm/relay/grammar/py3/RelayLexer.py b/python/tvm/relay/grammar/py3/RelayLexer.py index 80a0eba0db1a..6154049504c5 100644 --- a/python/tvm/relay/grammar/py3/RelayLexer.py +++ b/python/tvm/relay/grammar/py3/RelayLexer.py @@ -1,4 +1,4 @@ -# Generated from /workspace/python/tvm/relay/grammar/Relay.g4 by ANTLR 4.7.1 +# Generated from /home/marisa/Work/tvm/python/tvm/relay/grammar/Relay.g4 by ANTLR 4.7.2 from antlr4 import * from io import StringIO from typing.io import TextIO @@ -233,7 +233,7 @@ class RelayLexer(Lexer): def __init__(self, input=None, output:TextIO = sys.stdout): super().__init__(input, output) - self.checkVersion("4.7.1") + self.checkVersion("4.7.2") self._interp = LexerATNSimulator(self, self.atn, self.decisionsToDFA, PredictionContextCache()) self._actions = None self._predicates = None diff --git a/python/tvm/relay/grammar/py3/RelayParser.py b/python/tvm/relay/grammar/py3/RelayParser.py index a489580175d3..a30737644e8c 100644 --- a/python/tvm/relay/grammar/py3/RelayParser.py +++ b/python/tvm/relay/grammar/py3/RelayParser.py @@ -1,4 +1,4 @@ -# Generated from /workspace/python/tvm/relay/grammar/Relay.g4 by ANTLR 4.7.1 +# Generated from /home/marisa/Work/tvm/python/tvm/relay/grammar/Relay.g4 by ANTLR 4.7.2 # encoding: utf-8 from antlr4 import * from io import StringIO @@ -283,7 +283,7 @@ class RelayParser ( Parser ): def __init__(self, input:TokenStream, output:TextIO = sys.stdout): super().__init__(input, output) - self.checkVersion("4.7.1") + self.checkVersion("4.7.2") self._interp = ParserATNSimulator(self, self.atn, self.decisionsToDFA, self.sharedContextCache) self._predicates = None @@ -724,6 +724,8 @@ def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.Expr super().__init__(parser) self.copyFrom(ctx) + def SUB(self): + return self.getToken(RelayParser.SUB, 0) def expr(self): return self.getTypedRuleContext(RelayParser.ExprContext,0) @@ -867,6 +869,26 @@ def expr(self, i:int=None): else: return self.getTypedRuleContext(RelayParser.ExprContext,i) + def MUL(self): + return self.getToken(RelayParser.MUL, 0) + def DIV(self): + return self.getToken(RelayParser.DIV, 0) + def ADD(self): + return self.getToken(RelayParser.ADD, 0) + def SUB(self): + return self.getToken(RelayParser.SUB, 0) + def LT(self): + return self.getToken(RelayParser.LT, 0) + def GT(self): + return self.getToken(RelayParser.GT, 0) + def LE(self): + return self.getToken(RelayParser.LE, 0) + def GE(self): + return self.getToken(RelayParser.GE, 0) + def EQ(self): + return self.getToken(RelayParser.EQ, 0) + def NE(self): + return self.getToken(RelayParser.NE, 0) def accept(self, visitor:ParseTreeVisitor): if hasattr( visitor, "visitBinOp" ): diff --git a/python/tvm/relay/grammar/py3/RelayVisitor.py b/python/tvm/relay/grammar/py3/RelayVisitor.py index 30c802255c94..147dbb7975c7 100644 --- a/python/tvm/relay/grammar/py3/RelayVisitor.py +++ b/python/tvm/relay/grammar/py3/RelayVisitor.py @@ -1,4 +1,4 @@ -# Generated from /workspace/python/tvm/relay/grammar/Relay.g4 by ANTLR 4.7.1 +# Generated from /home/marisa/Work/tvm/python/tvm/relay/grammar/Relay.g4 by ANTLR 4.7.2 from antlr4 import * if __name__ is not None and "." in __name__: from .RelayParser import RelayParser diff --git a/python/tvm/relay/scope_builder.py b/python/tvm/relay/scope_builder.py index dfe3db187e07..16044c127e98 100644 --- a/python/tvm/relay/scope_builder.py +++ b/python/tvm/relay/scope_builder.py @@ -14,7 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""The scope builder interface """ + +"""The scope builder interface.""" from __future__ import absolute_import from . import expr as _expr diff --git a/src/relay/ir/alpha_equal.cc b/src/relay/ir/alpha_equal.cc index e16ffbbc3dd5..ea270277bb33 100644 --- a/src/relay/ir/alpha_equal.cc +++ b/src/relay/ir/alpha_equal.cc @@ -419,8 +419,8 @@ class AlphaEqualHandler: bool VisitExpr_(const LetNode* lhs, const Expr& other) final { if (const LetNode* rhs = other.as()) { - if (!ExprEqual(lhs->value, rhs->value)) return false; if (!MergeVarDecl(lhs->var, rhs->var)) return false; + if (!ExprEqual(lhs->value, rhs->value)) return false; return ExprEqual(lhs->body, rhs->body); } else { return false; diff --git a/src/relay/pass/dead_code.cc b/src/relay/pass/dead_code.cc index 54075f0699e6..b401dad33310 100644 --- a/src/relay/pass/dead_code.cc +++ b/src/relay/pass/dead_code.cc @@ -36,121 +36,94 @@ namespace tvm { namespace relay { +template +using VarMap = std::unordered_map; +using VarSet = std::unordered_set; + +class CalcDep; +class FindDef : private ExprVisitor { + private: + VarMap expr_map_; + + void VisitExpr_(const LetNode* l) final { + CHECK_EQ(expr_map_.count(l->var), 0); + expr_map_[l->var] = l->value; + VisitExpr(l->value); + VisitExpr(l->body); + } + + friend CalcDep; +}; + +class Eliminator : private ExprMutator { + private: + VarMap expr_map_; + VarMap use_map_; + bool inline_once_; + explicit Eliminator(const VarMap& expr_map, + const VarMap& use_map, + bool inline_once) : + expr_map_(expr_map), use_map_(use_map), inline_once_(inline_once) { } + friend CalcDep; + + bool HasLet(const Var& v) { + switch (use_map_[v]) { + case 0: + return false; + case 1: + return !inline_once_; + default: + return true; + } + } + + Expr VisitExpr_(const VarNode* op) final { + Var v = GetRef(op); + return (expr_map_.count(v) == 0 || HasLet(v)) ? v : VisitExpr(expr_map_[v]); + } + + Expr VisitExpr_(const LetNode* op) final { + Var v = op->var; + if (HasLet(v)) { + return LetNode::make(v, VisitExpr(op->value), VisitExpr(op->body)); + } else { + return VisitExpr(op->body); + } + } +}; + // calculate the dependency graph from expression class CalcDep : private ExprVisitor { public: static Expr Eliminate(const Expr& e, bool inline_once) { - CalcDep cd; - cd.Calculate(e); - Eliminator el(cd.expr_map_, cd.use_map_, cd.letrec_set_, inline_once); + FindDef fd; + fd(e); + CalcDep cd(fd.expr_map_); + cd(e); + Eliminator el(fd.expr_map_, cd.use_map_, inline_once); return el(e); } private: - template - using VarMap = std::unordered_map; - using VarSet = std::unordered_set; + explicit CalcDep(const VarMap& expr_map) : expr_map_(expr_map) { } VarMap expr_map_; VarMap use_map_; - VarSet letrec_set_; - bool count_ = true; - VarSet dead_worklist_; - VarSet current_letrec_; - - void LetRec(const std::function& func, const Var& v) { - current_letrec_.insert(v); - func(); - current_letrec_.erase(v); + + void VisitExpr(const Expr& e) final { + return ExprFunctor::VisitExpr(e); } void VisitExpr_(const LetNode* l) final { - if (count_) { - CHECK_EQ(expr_map_.count(l->var), 0); - CHECK_EQ(use_map_.count(l->var), 0); - expr_map_[l->var] = l->value; - use_map_[l->var] = 0; - dead_worklist_.insert(l->var); - LetRec([&]() { VisitExpr(l->value); }, l->var); - } VisitExpr(l->body); } - void VisitExpr(const Expr& e) final { - ExprFunctor::VisitExpr(e); - } - void VisitExpr_(const VarNode* v) final { Var var = GetRef(v); - if (expr_map_.count(var) == 0) { - return; - } - if (current_letrec_.count(var) == 0) { - if (count_) { - use_map_[var] += 1; - dead_worklist_.erase(var); - } else { - CHECK_GT(use_map_[var], 0) << var; - use_map_[var] -= 1; - if (use_map_[var] == 0) { - dead_worklist_.insert(var); - } - } - } else { - letrec_set_.insert(var); + ++use_map_[var]; + if (use_map_[var] == 1 && expr_map_.count(var) > 0) { + VisitExpr(expr_map_[var]); } } - - void Calculate(const Expr& v) { - VisitExpr(v); - count_ = false; - while (!dead_worklist_.empty()) { - Var dead = *(dead_worklist_.begin()); - dead_worklist_.erase(dead); - CHECK_EQ(use_map_[dead], 0); - if (expr_map_.count(dead) > 0) { - LetRec([&]() { VisitExpr(expr_map_[dead]); }, dead); - } - } - } - - class Eliminator : private ExprMutator { - private: - VarMap expr_map_; - VarMap use_map_; - VarSet letrec_set_; - bool inline_once_; - explicit Eliminator(const VarMap& expr_map, - const VarMap& use_map, - const VarSet& letrec_set, - bool inline_once) : - expr_map_(expr_map), use_map_(use_map), letrec_set_(letrec_set), inline_once_(inline_once) { } - friend CalcDep; - - bool HasLet(const Var& v) { - switch (use_map_[v]) { - case 0: - return false; - case 1: - return letrec_set_.count(v) > 0 || !inline_once_; - default: - return true; - } - } - - Expr VisitExpr_(const VarNode* op) final { - Var v = GetRef(op); - return (expr_map_.count(v) == 0 || HasLet(v)) ? v : VisitExpr(expr_map_[v]); - } - - Expr VisitExpr_(const LetNode* op) final { - Var v = op->var; - if (HasLet(v)) { - return LetNode::make(v, VisitExpr(op->value), VisitExpr(op->body)); - } else { - return VisitExpr(op->body); - } - } - }; }; Expr DeadCodeElimination(const Expr& e, bool inline_once) { diff --git a/src/relay/pass/partial_eval.cc b/src/relay/pass/partial_eval.cc index 869c056729e3..920a6e48bcaf 100644 --- a/src/relay/pass/partial_eval.cc +++ b/src/relay/pass/partial_eval.cc @@ -128,11 +128,11 @@ struct VarEqual { Expr PostProcess(const Expr&); -/*! \brief The base container type of Relay values. */ +/*! \brief A StaticNode contain some static data that the Partial Evaluator can use. */ class StaticNode : public RelayNode { public: static constexpr const char* _type_key = "relay.Static"; - TVM_DECLARE_BASE_NODE_INFO(ValueNode, RelayNode); + TVM_DECLARE_BASE_NODE_INFO(StaticNode, RelayNode); }; class Static : public NodeRef { @@ -174,7 +174,7 @@ struct STupleNode : StaticNode { TVM_DECLARE_NODE_TYPE_INFO(STupleNode, StaticNode); }; -RELAY_DEFINE_NODE_REF(STuple, STupleNode, Value); +RELAY_DEFINE_NODE_REF(STuple, STupleNode, Static); Static MkSTuple(const std::vector& fields) { return Static(make_node(fields)); @@ -187,7 +187,7 @@ struct STensorNode : StaticNode { TVM_DECLARE_NODE_TYPE_INFO(STensorNode, StaticNode); }; -RELAY_DEFINE_NODE_REF(STensor, STensorNode, Value); +RELAY_DEFINE_NODE_REF(STensor, STensorNode, Static); Static MkSTensor(const NDArray& data) { return Static(make_node(data)); @@ -202,7 +202,7 @@ struct SConstructorNode : StaticNode { TVM_DECLARE_NODE_TYPE_INFO(SConstructorNode, StaticNode); }; -RELAY_DEFINE_NODE_REF(SConstructor, SConstructorNode, Value); +RELAY_DEFINE_NODE_REF(SConstructor, SConstructorNode, Static); Static MkSConstructor(const Constructor& constructor, const std::vector& fields) { return Static(make_node(constructor, fields)); @@ -214,13 +214,14 @@ struct SRefNode : StaticNode { TVM_DECLARE_NODE_TYPE_INFO(SRefNode, StaticNode); }; -RELAY_DEFINE_NODE_REF(SRef, SRefNode, Value); +RELAY_DEFINE_NODE_REF(SRef, SRefNode, Static); Static MkSRef() { return Static(make_node()); } -using Func = std::function&, +using Func = std::function&, const Attrs&, const Array&, LetList*)>; @@ -232,12 +233,145 @@ struct SFuncNode : StaticNode { TVM_DECLARE_NODE_TYPE_INFO(SFuncNode, StaticNode); }; -RELAY_DEFINE_NODE_REF(SFunc, SFuncNode, Value); +RELAY_DEFINE_NODE_REF(SFunc, SFuncNode, Static); Static MkSFunc(const Func& func) { return Static(make_node(func)); } + +class FuelNode; +/*! \brief A join-semilattice with finite ascending chain. + * It mean that we can join two element to get an element, + * and for every element, there is only finite amount of join before getting back the same element. + * + * Every time we recurse, we do a join and require that progress must be made. + * This make sure we do not recurse infinitely in the Partial Evaluator. + */ +class Fuel : public NodeRef { + public: + Fuel() {} + explicit Fuel(NodePtr n) : NodeRef(n) {} + const FuelNode* operator->() const; + + using ContainerType = FuelNode; +}; + +class FuelNode : public RelayNode { + public: + // Please implement one of the following function or there will be infinite loop. + /*! \brief return the new Fuel, and whether progress is made. + * + * Note that progress is not symmetric - it only measure progress for (*this). + * + * Thus, if the generated is smaller then the argument of Join, + * and the generated is not smaller then (*this), + * progress should be false. + */ + virtual std::tuple Join(const Fuel& f) const { + bool progress = false; + auto ret = Join(f, &progress); + return std::make_tuple(ret, progress); + } + /*! \brief return the new Fuel, and write true only iff progress is made. */ + virtual Fuel Join(const Fuel& f, bool* progress) const { + CHECK(progress); + auto ret = Join(f); + *progress |= std::get<1>(ret); + return std::get<0>(ret); + } + static constexpr const char* _type_key = "relay.Fuel"; + TVM_DECLARE_BASE_NODE_INFO(FuelNode, RelayNode); +}; + +const FuelNode* Fuel::operator->() const { + return static_cast(node_.get()); +} + +Fuel MkFSeq(const std::vector& fuels); +struct FSeqNode : FuelNode { + std::vector fuels; + virtual Fuel Join(const Fuel& f, bool* progress) const final { + auto x = f.as(); + CHECK(x); + CHECK_EQ(fuels.size(), x->fuels.size()); + std::vector new_fuels; + for (size_t i = 0; i < fuels.size(); ++i) { + new_fuels.push_back(fuels[i]->Join(x->fuels[i], progress)); + } + return MkFSeq(new_fuels); + } + explicit FSeqNode(const std::vector& fuels) : fuels(fuels) { } + static constexpr const char* _type_key = "relay.FSeq"; + TVM_DECLARE_NODE_TYPE_INFO(FSeqNode, FuelNode); +}; + +RELAY_DEFINE_NODE_REF(FSeq, FSeqNode, Fuel); + +Fuel MkFSeq(const std::vector& fuels) { + return Fuel(make_node(fuels)); +} + +Fuel MkFTime(Time time); +struct FTimeNode : FuelNode { + Time time; + virtual std::tuple Join(const Fuel& f) const final { + auto x = f.as(); + CHECK(x); + Time new_time = std::min(time, x->time); + return std::make_tuple(MkFTime(new_time), new_time < time); + } + explicit FTimeNode(Time time) : time(time) { } + static constexpr const char* _type_key = "relay.FTime"; + TVM_DECLARE_NODE_TYPE_INFO(FTimeNode, FuelNode); +}; + +RELAY_DEFINE_NODE_REF(FTime, FTimeNode, Fuel); + +Fuel MkFTime(Time time) { + return Fuel(make_node(time)); +} + +Fuel MkFTValue(size_t tvalue); +/*! \brief If the pstatic is hold a positive integer scalar, that number, else 0. */ +struct FTValueNode : FuelNode { + size_t tvalue; + virtual std::tuple Join(const Fuel& f) const final { + auto x = f.as(); + CHECK(x); + size_t new_tvalue = std::min(tvalue, x->tvalue); + return std::make_tuple(MkFTValue(new_tvalue), new_tvalue < tvalue); + } + explicit FTValueNode(size_t tvalue) : tvalue(tvalue) { } + static constexpr const char* _type_key = "relay.FTValue"; + TVM_DECLARE_NODE_TYPE_INFO(FTValueNode, FuelNode); +}; + +RELAY_DEFINE_NODE_REF(FTValue, FTValueNode, Fuel); + +Fuel MkFTValue(size_t tvalue) { + return Fuel(make_node(tvalue)); +} + +/*! \brief Initially every element has Fuel of FBottom. It is the smallest element. + * + * Note that it is illegal to has FBottom inside some other Fuel - + * doing so break the finite ascending chain property. + */ +struct FBottomNode : FuelNode { + virtual std::tuple Join(const Fuel& f) const final { + return std::make_tuple(f, !f.as()); + } + static constexpr const char* _type_key = "relay.FBottom"; + TVM_DECLARE_NODE_TYPE_INFO(FBottomNode, FuelNode); +}; + +RELAY_DEFINE_NODE_REF(FBottom, FBottomNode, Fuel); + +Fuel MkFBottom() { + return Fuel(make_node()); +} + /*! * \brief A stack frame in the Relay interpreter. * @@ -469,6 +603,18 @@ class PartialEvaluator : public ExprFunctor return ret; } + PStatic VisitExpr(const Expr& e, LetList* ll, const Var& name) { + if (auto* op = e.as()) { + if (op->op.same_as(WithFuncIdOp())) { + CHECK_EQ(op->args.size(), 1); + return VisitExpr(op->args[0], ll, name); + } + } + PStatic ret = e.as() ? VisitFunc(Downcast(e), ll, name) : VisitExpr(e, ll); + CHECK(IsAtomic(ret->dynamic)) << ret->dynamic; + return ret; + } + PStatic VisitExpr_(const ConstantNode* op, LetList* ll) final { return HasStatic(MkSTensor(op->data.CopyTo(context_)), ll->Push(GetRef(op))); } @@ -504,7 +650,7 @@ class PartialEvaluator : public ExprFunctor InitializeFuncId(func); Func f = VisitFuncStatic(func, gv); gv_map_.insert({gv, HasStatic(MkSFunc(f), gv)}); - func = AsFunc(PostProcess(VisitFuncDynamic(func, f))); + func = AsFunc(PostProcess(VisitFuncDynamic(func, f, gv))); mod_->Update(gv, func); } return gv_map_.at(gv); @@ -515,7 +661,7 @@ class PartialEvaluator : public ExprFunctor } PStatic VisitExpr_(const LetNode* op, LetList* ll) final { - env_.Insert(op->var, VisitExpr(op->value, ll)); + env_.Insert(op->var, VisitExpr(op->value, ll, op->var)); return VisitExpr(op->body, ll); } @@ -588,34 +734,53 @@ class PartialEvaluator : public ExprFunctor x_dyn.push_back(ps->dynamic); } if (f->pstatic.defined()) { - return Downcast(f->pstatic)->func(x, op->attrs, op->type_args, ll); + return Downcast(f->pstatic)->func(f, x, op->attrs, op->type_args, ll); } else { store_.Invalidate(); return NoStatic(ll->Push(CallNode::make(f->dynamic, x_dyn, op->attrs, op->type_args))); } } - struct TimeFrame { + struct FuelFrame { PartialEvaluator* pe_; FuncId fid_; - std::vector