From 76aac973b6184656a2ef89bc9bdbf0d94dd454e7 Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 27 Feb 2017 21:25:48 -0800 Subject: [PATCH 1/7] [ARITH/VISITOR] Modular Analysis, ExprFunctor, StmtFunctor --- include/tvm/ir_functor_ext.h | 261 ++++++++++++++++++++ include/tvm/ir_mutator.h | 62 ----- include/tvm/ir_visitor.h | 47 +++- python/tvm/api.py | 38 ++- python/tvm/arith.py | 6 +- src/api/api_arith.cc | 6 + src/arithmetic/int_set_internal.h | 18 ++ src/arithmetic/modular.cc | 159 ++++++++++++ src/arithmetic/modular.h | 61 +++++ src/pass/ir_mutator.cc | 98 +------- tests/cpp/ir_functor_test.cc | 62 ++++- tests/python/unittest/test_arith_intset.py | 5 + tests/python/unittest/test_arith_modular.py | 32 +++ 13 files changed, 686 insertions(+), 169 deletions(-) create mode 100644 include/tvm/ir_functor_ext.h create mode 100644 src/arithmetic/modular.cc create mode 100644 src/arithmetic/modular.h create mode 100644 tests/python/unittest/test_arith_modular.py diff --git a/include/tvm/ir_functor_ext.h b/include/tvm/ir_functor_ext.h new file mode 100644 index 000000000000..9fce6ae9dcf5 --- /dev/null +++ b/include/tvm/ir_functor_ext.h @@ -0,0 +1,261 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file ir_functor_ext.h + * \brief More powerful Visitor that allows define function signatures. + */ +#ifndef TVM_IR_FUNCTOR_EXT_H_ +#define TVM_IR_FUNCTOR_EXT_H_ + +#include +#include "./ir.h" + +namespace tvm { +namespace ir { + +/*! + * \brief A dynamical functor that dispatches on in the first Expr argument. + * You can use this as a more powerful Visitor, since it allows you to + * define function signatures of Visit Function. + * + * \code + * // A functor that set variable to b. and calculate results. + * class MyExprFunctor + * : public ir::ExprFunctor { + * public: + * int VisitExpr_(const Variable* op, int b) final { + * return b; + * } + * int VisitExpr_(const IntImm* op, int b) final { + * return op->value; + * } + * int VisitExpr_(const Add* op, int b) final { + * return Visit(op->a, b) + Visit(op->b, b); + * } + * }; + * MyExprFunctor f; + * Var x("x"); + * CHECK_EQ(f(x + 1, 2), 3); + * \endcode + * + * \note Why do we need this more powerful Functor: + * + * We often need to implement a transformer tasks. + * Say we want to take Expr and transform it to some analysis result, + * This easily be done incorrectly using plain Visitor. See IRVisitor's + * document for possible error cases. + * + * \tparam FType function signiture + * This type if only defined for FType with function signiture R(const Expr&, Args...) + */ +template +class ExprFunctor; +/*! + * \brief Same as ExprFunctor except it is applied on statements + * \tparam FType The function signature. + */ +template +class StmtFunctor; + +// functions to be overriden. +#define EXPR_FUNCTOR_DEFAULT { \ + return VisitExprDefault_(op, std::forward(args)...); \ + } +#define STMT_FUNCTOR_DEFAULT { \ + return VisitStmtDefault_(op, std::forward(args)...); \ +} + +#define IR_EXPR_FUNCTOR_DISPATCH(OP) \ + vtable.template set_dispatch( \ + [](const NodeRef& n, TSelf* self, Args... args) { \ + return self->VisitExpr_(static_cast(n.node_.get()), \ + std::forward(args)...); \ + }); \ + +#define IR_STMT_FUNCTOR_DISPATCH(OP) \ + vtable.template set_dispatch( \ + [](const NodeRef& n, TSelf* self, Args... args) { \ + return self->VisitStmt_(static_cast(n.node_.get()), \ + std::forward(args)...); \ + }); \ + +template +class ExprFunctor { + private: + using TSelf = ExprFunctor; + using FType = IRFunctor; + + public: + /*! \brief the result type of this functor */ + using result_type = R; + /*! \brief virtual destructor */ + virtual ~ExprFunctor() {} + /*! + * \brief Same as call. + * \param n The expression node. + * \param args Additional arguments. + * \return The result of the call + */ + R operator()(const Expr& n, Args... args) { + return VisitExpr(n, std::forward(args)...); + } + /*! + * \brief The functor call. + * \param n The expression node. + * \param args Additional arguments. + * \return The result of the call + */ + virtual R VisitExpr(const Expr& n, Args... args) { + static FType vtable = InitVTable(); + return vtable(n, this, std::forward(args)...); + } + // Functions that can be overriden by subclass + virtual R VisitExpr_(const Variable* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const Load* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const Let* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const Call* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const Add* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const Sub* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const Mul* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const Div* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const Mod* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const Min* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const Max* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const EQ* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const NE* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const LT* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const LE* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const GT* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const GE* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const And* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const Or* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const Reduce* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const Cast* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const Not* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const Select* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const Ramp* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const Broadcast* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const IntImm* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const UIntImm* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const FloatImm* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const StringImm* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExprDefault_(const Node* op, Args ...) { + LOG(FATAL) << "Do not have a default for " << op->type_key(); + return R(); + } + + private: + // initialize the vtable. + static FType InitVTable() { + FType vtable; + // Set dispatch + IR_EXPR_FUNCTOR_DISPATCH(Variable); + IR_EXPR_FUNCTOR_DISPATCH(Load); + IR_EXPR_FUNCTOR_DISPATCH(Let); + IR_EXPR_FUNCTOR_DISPATCH(Call); + IR_EXPR_FUNCTOR_DISPATCH(Add); + IR_EXPR_FUNCTOR_DISPATCH(Sub); + IR_EXPR_FUNCTOR_DISPATCH(Mul); + IR_EXPR_FUNCTOR_DISPATCH(Div); + IR_EXPR_FUNCTOR_DISPATCH(Mod); + IR_EXPR_FUNCTOR_DISPATCH(Min); + IR_EXPR_FUNCTOR_DISPATCH(Max); + IR_EXPR_FUNCTOR_DISPATCH(EQ); + IR_EXPR_FUNCTOR_DISPATCH(NE); + IR_EXPR_FUNCTOR_DISPATCH(LT); + IR_EXPR_FUNCTOR_DISPATCH(LE); + IR_EXPR_FUNCTOR_DISPATCH(GT); + IR_EXPR_FUNCTOR_DISPATCH(GE); + IR_EXPR_FUNCTOR_DISPATCH(And); + IR_EXPR_FUNCTOR_DISPATCH(Or); + IR_EXPR_FUNCTOR_DISPATCH(Reduce); + IR_EXPR_FUNCTOR_DISPATCH(Cast); + IR_EXPR_FUNCTOR_DISPATCH(Not); + IR_EXPR_FUNCTOR_DISPATCH(Select); + IR_EXPR_FUNCTOR_DISPATCH(Ramp); + IR_EXPR_FUNCTOR_DISPATCH(Broadcast); + IR_EXPR_FUNCTOR_DISPATCH(IntImm); + IR_EXPR_FUNCTOR_DISPATCH(UIntImm); + IR_EXPR_FUNCTOR_DISPATCH(FloatImm); + IR_EXPR_FUNCTOR_DISPATCH(StringImm); + return vtable; + } +}; + +template +class StmtFunctor { + private: + using TSelf = StmtFunctor; + using FType = IRFunctor; + + public: + /*! \brief the result type of this functor */ + using result_type = R; + /*! \brief virtual destructor */ + virtual ~StmtFunctor() {} + /*! + * \brief Same as call. + * \param n The stmt node. + * \param args Additional arguments. + * \return The result of the call + */ + R operator()(const Stmt& n, Args... args) { + return VisitStmt(n, std::forward(args)...); + } + /*! + * \brief The functor call. + * \param n The stmt node. + * \param args Additional arguments. + * \return The result of the call + */ + virtual R VisitStmt(const Stmt& n, Args... args) { + static FType vtable = InitVTable(); + return vtable(n, this, std::forward(args)...); + } + // Functions that can be overriden by subclass + virtual R VisitStmt_(const LetStmt* op, Args... args) STMT_FUNCTOR_DEFAULT; + virtual R VisitStmt_(const AttrStmt* op, Args... args) STMT_FUNCTOR_DEFAULT; + virtual R VisitStmt_(const IfThenElse* op, Args... args) STMT_FUNCTOR_DEFAULT; + virtual R VisitStmt_(const For* op, Args... args) STMT_FUNCTOR_DEFAULT; + virtual R VisitStmt_(const Allocate* op, Args... args) STMT_FUNCTOR_DEFAULT; + virtual R VisitStmt_(const Store* op, Args... args) STMT_FUNCTOR_DEFAULT; + virtual R VisitStmt_(const Free* op, Args... args) STMT_FUNCTOR_DEFAULT; + virtual R VisitStmt_(const AssertStmt* op, Args... args) STMT_FUNCTOR_DEFAULT; + virtual R VisitStmt_(const ProducerConsumer* op, Args... args) STMT_FUNCTOR_DEFAULT; + virtual R VisitStmt_(const Provide* op, Args... args) STMT_FUNCTOR_DEFAULT; + virtual R VisitStmt_(const Realize* op, Args... args) STMT_FUNCTOR_DEFAULT; + virtual R VisitStmt_(const Block* op, Args... args) STMT_FUNCTOR_DEFAULT; + virtual R VisitStmt_(const Evaluate* op, Args... args) STMT_FUNCTOR_DEFAULT; + virtual R VisitStmtDefault_(const Node* op, Args ...) { + LOG(FATAL) << "Do not have a default for " << op->type_key(); + return R(); + } + + private: + // initialize the vtable. + static FType InitVTable() { + FType vtable; + IR_STMT_FUNCTOR_DISPATCH(LetStmt); + IR_STMT_FUNCTOR_DISPATCH(AttrStmt); + IR_STMT_FUNCTOR_DISPATCH(IfThenElse); + IR_STMT_FUNCTOR_DISPATCH(For); + IR_STMT_FUNCTOR_DISPATCH(Allocate); + IR_STMT_FUNCTOR_DISPATCH(Store); + IR_STMT_FUNCTOR_DISPATCH(Free); + IR_STMT_FUNCTOR_DISPATCH(AssertStmt); + IR_STMT_FUNCTOR_DISPATCH(ProducerConsumer); + IR_STMT_FUNCTOR_DISPATCH(Provide); + IR_STMT_FUNCTOR_DISPATCH(Realize); + IR_STMT_FUNCTOR_DISPATCH(Block); + IR_STMT_FUNCTOR_DISPATCH(Evaluate); + return vtable; + } +}; + +#undef IR_STMT_FUNCTOR_DISPATCH +#undef IR_EXPR_FUNCTOR_DISPATCH +#undef EXPR_FUNCTOR_DEFAULT +#undef STMT_FUNCTOR_DEFAULT + +} // namespace ir +} // namespace tvm +#endif // TVM_IR_FUNCTOR_EXT_H_ diff --git a/include/tvm/ir_mutator.h b/include/tvm/ir_mutator.h index c428232698e8..1a84cb24a1e8 100644 --- a/include/tvm/ir_mutator.h +++ b/include/tvm/ir_mutator.h @@ -55,59 +55,23 @@ class IRMutator { static FMutateStmt& vtable_stmt(); // NOLINT(*) // Set of overloadable functions // The underscore allows Mutate not to be shadowed by inheritance - virtual Stmt Mutate_(const Variable* op, const Stmt& s); virtual Stmt Mutate_(const LetStmt* op, const Stmt& s); virtual Stmt Mutate_(const AttrStmt* op, const Stmt& s); virtual Stmt Mutate_(const IfThenElse* op, const Stmt& s); virtual Stmt Mutate_(const For* op, const Stmt& s); virtual Stmt Mutate_(const Allocate* op, const Stmt& s); - virtual Stmt Mutate_(const Load* op, const Stmt& s); virtual Stmt Mutate_(const Store* op, const Stmt& s); - virtual Stmt Mutate_(const Let* op, const Stmt& s); virtual Stmt Mutate_(const Free* op, const Stmt& s); - virtual Stmt Mutate_(const Call* op, const Stmt& s); - virtual Stmt Mutate_(const Add* op, const Stmt& e); - virtual Stmt Mutate_(const Sub* op, const Stmt& e); - virtual Stmt Mutate_(const Mul* op, const Stmt& e); - virtual Stmt Mutate_(const Div* op, const Stmt& e); - virtual Stmt Mutate_(const Mod* op, const Stmt& e); - virtual Stmt Mutate_(const Min* op, const Stmt& e); - virtual Stmt Mutate_(const Max* op, const Stmt& e); - virtual Stmt Mutate_(const EQ* op, const Stmt& e); - virtual Stmt Mutate_(const NE* op, const Stmt& e); - virtual Stmt Mutate_(const LT* op, const Stmt& e); - virtual Stmt Mutate_(const LE* op, const Stmt& e); - virtual Stmt Mutate_(const GT* op, const Stmt& e); - virtual Stmt Mutate_(const GE* op, const Stmt& e); - virtual Stmt Mutate_(const And* op, const Stmt& e); - virtual Stmt Mutate_(const Or* op, const Stmt& e); - virtual Stmt Mutate_(const Reduce* op, const Stmt& s); - virtual Stmt Mutate_(const Cast* op, const Stmt& s); - virtual Stmt Mutate_(const Not* op, const Stmt& s); - virtual Stmt Mutate_(const Select* op, const Stmt& s); - virtual Stmt Mutate_(const Ramp* op, const Stmt& s); - virtual Stmt Mutate_(const Broadcast* op, const Stmt& e); virtual Stmt Mutate_(const AssertStmt* op, const Stmt& e); virtual Stmt Mutate_(const ProducerConsumer* op, const Stmt& e); virtual Stmt Mutate_(const Provide* op, const Stmt& e); virtual Stmt Mutate_(const Realize* op, const Stmt& s); virtual Stmt Mutate_(const Block* op, const Stmt& s); virtual Stmt Mutate_(const Evaluate* op, const Stmt& e); - virtual Stmt Mutate_(const IntImm* op, const Stmt& e); - virtual Stmt Mutate_(const UIntImm* op, const Stmt& e); - virtual Stmt Mutate_(const FloatImm* op, const Stmt& e); - virtual Stmt Mutate_(const StringImm* op, const Stmt& e); virtual Expr Mutate_(const Variable* op, const Expr& e); - virtual Expr Mutate_(const LetStmt* op, const Expr& e); - virtual Expr Mutate_(const AttrStmt* op, const Expr& e); - virtual Expr Mutate_(const IfThenElse* op, const Expr& e); - virtual Expr Mutate_(const For* op, const Expr& e); - virtual Expr Mutate_(const Allocate* op, const Expr& e); virtual Expr Mutate_(const Load* op, const Expr& e); - virtual Expr Mutate_(const Store* op, const Expr& e); virtual Expr Mutate_(const Let* op, const Expr& e); - virtual Expr Mutate_(const Free* op, const Expr& e); virtual Expr Mutate_(const Call* op, const Expr& e); virtual Expr Mutate_(const Add* op, const Expr& e); virtual Expr Mutate_(const Sub* op, const Expr& e); @@ -130,38 +94,12 @@ class IRMutator { virtual Expr Mutate_(const Select* op, const Expr& e); virtual Expr Mutate_(const Ramp* op, const Expr& e); virtual Expr Mutate_(const Broadcast* op, const Expr& e); - virtual Expr Mutate_(const AssertStmt* op, const Expr& e); - virtual Expr Mutate_(const ProducerConsumer* op, const Expr& e); - virtual Expr Mutate_(const Provide* op, const Expr& e); - virtual Expr Mutate_(const Realize* op, const Expr& e); - virtual Expr Mutate_(const Block* op, const Expr& e); - virtual Expr Mutate_(const Evaluate* op, const Expr& e); virtual Expr Mutate_(const IntImm* op, const Expr& e); virtual Expr Mutate_(const UIntImm* op, const Expr& e); virtual Expr Mutate_(const FloatImm* op, const Expr& e); virtual Expr Mutate_(const StringImm* op, const Expr& e); }; -/*! - * \brief Example on how to subclass and override behavior of IRMutator - */ -class IRMutatorExample : public IRMutator { - public: - Expr Mutate(Expr expr) final { - static const FMutateExpr& f = IRMutatorExample::vtable_expr(); - return (f.can_dispatch(expr) ? - f(expr, expr, this) : IRMutator::Mutate(expr)); - } - Stmt Mutate(Stmt stmt) final { - static const FMutateStmt& f = IRMutatorExample::vtable_stmt(); - return (f.can_dispatch(stmt) ? - f(stmt, stmt, this) : IRMutator::Mutate(stmt)); - } - // to be implemented by child class - static FMutateExpr& vtable_expr(); // NOLINT(*) - static FMutateStmt& vtable_stmt(); // NOLINT(*) -}; - } // namespace ir } // namespace tvm #endif // TVM_IR_MUTATOR_H_ diff --git a/include/tvm/ir_visitor.h b/include/tvm/ir_visitor.h index 6bfbce25a0df..712f865aa2f7 100644 --- a/include/tvm/ir_visitor.h +++ b/include/tvm/ir_visitor.h @@ -6,6 +6,7 @@ #ifndef TVM_IR_VISITOR_H_ #define TVM_IR_VISITOR_H_ +#include #include "./ir.h" namespace tvm { @@ -17,7 +18,51 @@ namespace ir { * This IRVisitor is implemented via IRFunctor * This enables extensions of possible new Node. * - * \sa IRFunctor, PostOrderVisit + * \sa ExprFunctor, StmtFunctor, PostOrderVisit + * + * \note If you need to return values during Visit: + * - If it is mutaion of the IR, use IRMutator + * - If you want to return other things, consider use ExprFunctor/StmtFunctor + * - Watch out for possible bug pattern if you use IRVisitor to simulate returns. + * + * \code + * + * // This is an example code to show cases for traps in IRVisitor + * // The use case is to count number of Variables in the ir tree. + * class MyCounter : public IRVisitor { + * public: + * int Count(const NodeRef& n) { + * ret_ = 0; + * this->Visit(n); + * return ret_; + * } + * void Visit_(const Variable* op) final { + * ret_ = 1; + * } + * void Visit_(const Add* op) final { + * ret_ = count(op->a) + count(op->b); + * } + + * private: + * int ret_; + * }; + * MyCounter counter; + * Var x("x"); + * // this returns 2 + * CHECK_EQ(counter.Count(x + x), 2); + * // Think what is the result of the following count + * counter.count(Max::make(x, x)); + * // The result is actually 1 + * // This is because Visit is not overriden for Max + * // so it simply calls Visit for the left and right children + * // and because Count is not called, ret_ is not cleared. + * // There can also be cases where ret_ is forgetten to be set. + * + * // These traps may not happen if we program carefully + * // But it is recommended to use ExprFunctor, which allows direct + * // return the value, this helps us to avoid such problems. + * \encode + * */ class IRVisitor { public: diff --git a/python/tvm/api.py b/python/tvm/api.py index d6c81bac69e3..72960dae628d 100644 --- a/python/tvm/api.py +++ b/python/tvm/api.py @@ -274,33 +274,51 @@ def sum(expr, axis): return x -def min(expr, axis): - """Create a min expression over axis +def min(lhs, rhs=None, axis=None): + """Create a min expression. Parameters ---------- - expr : Expr - The source expression. + lhs : Expr + The left hand expression. - axis : IterVar + rhs : Expr, optional + The right hand expression. + + axis : IterVar, optional The reduction IterVar axis """ + if rhs and axis: + raise ValueError("Can only take one argument, rhs or axis") + if isinstance(rhs, (_collections.IterVar, list)): + axis, rhs = rhs, axis + if rhs: + return _make.Min(lhs, rhs) axis = axis if isinstance(axis, list) else [axis] x = _make.Reduce("Min", expr, axis) return x -def max(expr, axis): - """Create a min expression over axis +def max(lhs, rhs=None, axis=None): + """Create a max expression. Parameters ---------- - expr : Expr - The source expression. + lhs : Expr + The left hand expression. - axis : IterVar + rhs : Expr, optional + The right hand expression. + + axis : IterVar, optional The reduction IterVar axis """ + if rhs and axis: + raise ValueError("Can only take one argument, rhs or axis") + if isinstance(rhs, (_collections.IterVar, list)): + axis, rhs = rhs, axis + if rhs: + return _make.Max(lhs, rhs) axis = axis if isinstance(axis, list) else [axis] x = _make.Reduce("Max", expr, axis) return x diff --git a/python/tvm/arith.py b/python/tvm/arith.py index c3fad6670749..3440656b7dd4 100644 --- a/python/tvm/arith.py +++ b/python/tvm/arith.py @@ -5,7 +5,6 @@ from ._ctypes._node import NodeBase, register_node from . import _api_internal -@register_node class IntSet(NodeBase): """Represent a set of integer in one dimension.""" def is_nothing(self): @@ -33,3 +32,8 @@ def max(self): class StrideSet(IntSet): """Represent set of strided integers""" pass + +@register_node +class ModularSet(IntSet): + """Represent range of (coeff * x + base) for x in Z """ + pass diff --git a/src/api/api_arith.cc b/src/api/api_arith.cc index 7edbe3eec2a8..1866d9b49970 100644 --- a/src/api/api_arith.cc +++ b/src/api/api_arith.cc @@ -7,6 +7,7 @@ #include #include #include "../arithmetic/int_set.h" +#include "../arithmetic/modular.h" namespace tvm { namespace arith { @@ -21,6 +22,11 @@ TVM_REGISTER_API(_arith_intset_interval) *ret = IntSet::interval(args[0], args[1]); }); +TVM_REGISTER_API(_arith_EvalModular) +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = EvalModular(args[0], Map()); + }); + TVM_REGISTER_API(_arith_DeduceBound) .set_body([](TVMArgs args, TVMRetValue *ret) { *ret = DeduceBound(args[0], args[1], args[2]); diff --git a/src/arithmetic/int_set_internal.h b/src/arithmetic/int_set_internal.h index f0fb709ce885..7e57cb5cbcf6 100644 --- a/src/arithmetic/int_set_internal.h +++ b/src/arithmetic/int_set_internal.h @@ -9,6 +9,7 @@ #include #include #include "./int_set.h" +#include "./modular.h" namespace tvm { namespace arith { @@ -54,6 +55,23 @@ struct StrideSet : public IntSetNode { TVM_DECLARE_NODE_TYPE_INFO(StrideSet, IntSetNode); }; +/*! + * \brief Set represented by range of ModularEntry. + * Used for front-end modular analysis. + */ +struct ModularSet : public IntSetNode { + /*! \brief Internal modular entry */ + ModularEntry e; + + void VisitAttrs(AttrVisitor* v) final { + v->Visit("base", &(e.base)); + v->Visit("coeff", &(e.coeff)); + } + static constexpr const char* _type_key = "ModularSet"; + TVM_DECLARE_NODE_TYPE_INFO(ModularSet, IntSetNode); +}; + + } // namespace arith } // namespace tvm diff --git a/src/arithmetic/modular.cc b/src/arithmetic/modular.cc new file mode 100644 index 000000000000..c487701064f9 --- /dev/null +++ b/src/arithmetic/modular.cc @@ -0,0 +1,159 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file modular.cc + * \brief Modular analysis + */ +#include +#include +#include +#include +#include "./modular.h" +#include "./int_set_internal.h" + +namespace tvm { +namespace arith { + +using namespace ir; + +class ModularEvaluator + : public ExprFunctor { + public: + explicit ModularEvaluator( + const std::unordered_map< + const Variable*, ModularEntry>& mod_map) + : mod_map_(mod_map) { + } + ModularEntry Eval(const Expr& e) { + return VisitExpr(e); + } + // default + ModularEntry VisitExprDefault_(const Node*) final { + return ModularEntry::everything(); + } + // override combination rules. + ModularEntry VisitExpr_(const IntImm* op) final { + if (op->value < std::numeric_limits::max()) { + ModularEntry ret; + ret.base = static_cast(op->value); + ret.coeff = 0; + return ret; + } else { + return ModularEntry::everything(); + } + } + ModularEntry VisitExpr_(const UIntImm* op) final { + if (op->value < static_cast( + std::numeric_limits::max())) { + ModularEntry ret; + ret.base = static_cast(op->value); + ret.coeff = 0; + return ret; + } else { + return ModularEntry::everything(); + } + } + ModularEntry VisitExpr_(const Variable* op) final { + auto it = mod_map_.find(op); + if (it != mod_map_.end()) { + return it->second; + } else { + return ModularEntry::everything(); + } + } + ModularEntry VisitExpr_(const Add* op) final { + ModularEntry a = Eval(op->a); + ModularEntry b = Eval(op->b); + ModularEntry ret; + ret.coeff = ZeroAwareGCD(a.coeff, b.coeff); + ret.base = BaseSimplify(a.base + b.base, ret.coeff); + return ret; + } + ModularEntry VisitExpr_(const Sub* op) final { + ModularEntry a = Eval(op->a); + ModularEntry b = Eval(op->b); + ModularEntry ret; + ret.coeff = ZeroAwareGCD(a.coeff, b.coeff); + ret.base = BaseSimplify(a.base - b.base, ret.coeff); + return ret; + } + ModularEntry VisitExpr_(const Mul* op) final { + ModularEntry a = Eval(op->a); + ModularEntry b = Eval(op->b); + // Simplification rule, x, y, z are in Z + // (p x + n) (q y + m) + // -> pq xy + pm x + qn y + mn + // -> pq z + pm x + qn y + mn + int pq = a.coeff * b.coeff; + int pm = a.coeff * b.base; + int qn = a.base * b.coeff; + ModularEntry ret; + ret.coeff = ZeroAwareGCD(pq, ZeroAwareGCD(pm, qn)); + ret.base = BaseSimplify(a.base * b.base, ret.coeff); + return ret; + } + ModularEntry VisitExpr_(const Div* op) final { + // a c x / c -> a x + // We cannot do cases where offset is non-zero + // because of different integer rounding in pos/neg + ModularEntry a = Eval(op->a); + ModularEntry b = Eval(op->b); + if (b.coeff == 0 && + a.base == 0) { + CHECK_NE(b.base, 0); + if (a.coeff % b.base == 0) { + ModularEntry ret; + ret.coeff = a.coeff / b.base; + ret.base = 0; + return ret; + } + } + return ModularEntry::everything(); + } + + private: + const std::unordered_map< + const Variable*, ModularEntry>& mod_map_; + + // simplify the base by putting it in range. + static int BaseSimplify(int base, int coeff) { + if (coeff == 0) return base; + base = base % coeff; + if (base < 0) base += coeff; + return base; + } + static int ZeroAwareGCD(int a, int b) { + CHECK_GE(a, 0); + CHECK_GE(b, 0); + if (a < b) std::swap(a, b); + if (b == 0) return a; + // perform GCD (greatest common divisor) + // ax + by = gcd(a, b) z if a != 0, b != 0 + while (a % b != 0) { + a = a % b; + std::swap(a, b); + } + return b; + } +}; + +ModularEntry EvalModular( + const Expr& e, + const std::unordered_map& mod_map) { + return ModularEvaluator(mod_map)(e); +} + +IntSet EvalModular(const Expr& e, + const Map& mod_map) { + std::unordered_map mmap; + for (auto& kv : mod_map) { + const ModularSet* m = kv.second.as(); + CHECK(m) << "Need to pass ModularSet for Modular Analysis"; + mmap[kv.first.get()] = m->e; + } + std::shared_ptr n = std::make_shared(); + n->e = ModularEvaluator(mmap)(e); + return IntSet(n); +} + +} // namespace arith +} // namespace tvm diff --git a/src/arithmetic/modular.h b/src/arithmetic/modular.h new file mode 100644 index 000000000000..bb51901a65f3 --- /dev/null +++ b/src/arithmetic/modular.h @@ -0,0 +1,61 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file modular.h + * \brief Modular integer set analysis + */ +#ifndef TVM_ARITHMETIC_MODULAR_H_ +#define TVM_ARITHMETIC_MODULAR_H_ + +#include +#include "./int_set.h" + +namespace tvm { +namespace arith { + +/*! + * \brief Range of a linear integer function. + * Use to do specify the possible index values. + * + * set = { base + coeff * x | x \in Z } + * + * When coeff != 0, it can also be written as + * set = { n | n % coeff == base } + * + * This is useful to decide if the index is dividable by certain value. + * For example, if index = 0 + 4 x, then we know it can be divided by 4. + */ +struct ModularEntry { + /*! \brief The base */ + int base; + /*! \brief linear co-efficient */ + int coeff; + + /*! \return entry represent everything */ + static ModularEntry everything() { + // always safe to set 0 + x, so it can be everything. + ModularEntry e; + e.base = 0; e.coeff = 1; + return e; + } +}; + +/*! + * \brief Evaluate the expression with modular analysis + * \param e The expression to be evaluated. + * \param mod_map Map of modular statistics of known variables. + * \return The ModularEntry covering all possible value of e. + */ +ModularEntry EvalModular( + const Expr& e, + const std::unordered_map& mod_map); +/*! + * \brief Same as EvalModular, used by front-end. + * \param e The expression to be evaluated. + * \param mod_map Map of modular statistics of known variables. + * \return A ModularSet covering all possible value of e. + */ +IntSet EvalModular(const Expr& e, + const Map& mod_map); +} // namespace arith +} // namespace tvm +#endif // TVM_ARITHMETIC_MODULAR_H_ diff --git a/src/pass/ir_mutator.cc b/src/pass/ir_mutator.cc index 07f2b6d21b28..bab7471c0561 100644 --- a/src/pass/ir_mutator.cc +++ b/src/pass/ir_mutator.cc @@ -140,10 +140,6 @@ Stmt IRMutator::Mutate_(const IfThenElse *op, const Stmt& s) { } } -Stmt IRMutator::Mutate_(const Load *op, const Stmt& s) { - return s; -} - Stmt IRMutator::Mutate_(const Store *op, const Stmt& s) { Expr value = this->Mutate(op->value); Expr index = this->Mutate(op->index); @@ -234,84 +230,24 @@ Stmt IRMutator::Mutate_(const Evaluate *op, const Stmt& s) { } } -#define DEFINE_OP_RETURN_SELF_STMT_MUTATE_(OP) \ - Stmt IRMutator::Mutate_(const OP *op, const Stmt& s) { \ - return s; \ - } - -DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Variable) -DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Let) -DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Free) -DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Call) -DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Add) -DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Sub) -DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Mul) -DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Div) -DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Mod) -DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Min) -DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Max) -DEFINE_OP_RETURN_SELF_STMT_MUTATE_(EQ) -DEFINE_OP_RETURN_SELF_STMT_MUTATE_(NE) -DEFINE_OP_RETURN_SELF_STMT_MUTATE_(LT) -DEFINE_OP_RETURN_SELF_STMT_MUTATE_(LE) -DEFINE_OP_RETURN_SELF_STMT_MUTATE_(GT) -DEFINE_OP_RETURN_SELF_STMT_MUTATE_(GE) -DEFINE_OP_RETURN_SELF_STMT_MUTATE_(And) -DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Or) -DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Reduce) -DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Cast) -DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Not) -DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Select) -DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Ramp) -DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Broadcast) -DEFINE_OP_RETURN_SELF_STMT_MUTATE_(IntImm) -DEFINE_OP_RETURN_SELF_STMT_MUTATE_(UIntImm) -DEFINE_OP_RETURN_SELF_STMT_MUTATE_(FloatImm) -DEFINE_OP_RETURN_SELF_STMT_MUTATE_(StringImm) +Stmt IRMutator::Mutate_(const Free *op, const Stmt& s) { + return s; +} TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt) -.DISPATCH_TO_MUTATE_STMT(Variable) .DISPATCH_TO_MUTATE_STMT(LetStmt) .DISPATCH_TO_MUTATE_STMT(AttrStmt) .DISPATCH_TO_MUTATE_STMT(IfThenElse) .DISPATCH_TO_MUTATE_STMT(For) .DISPATCH_TO_MUTATE_STMT(Allocate) -.DISPATCH_TO_MUTATE_STMT(Load) .DISPATCH_TO_MUTATE_STMT(Store) -.DISPATCH_TO_MUTATE_STMT(Let) .DISPATCH_TO_MUTATE_STMT(Free) -.DISPATCH_TO_MUTATE_STMT(Call) -.DISPATCH_TO_MUTATE_STMT(Add) -.DISPATCH_TO_MUTATE_STMT(Sub) -.DISPATCH_TO_MUTATE_STMT(Mul) -.DISPATCH_TO_MUTATE_STMT(Div) -.DISPATCH_TO_MUTATE_STMT(Mod) -.DISPATCH_TO_MUTATE_STMT(Min) -.DISPATCH_TO_MUTATE_STMT(Max) -.DISPATCH_TO_MUTATE_STMT(EQ) -.DISPATCH_TO_MUTATE_STMT(NE) -.DISPATCH_TO_MUTATE_STMT(LT) -.DISPATCH_TO_MUTATE_STMT(LE) -.DISPATCH_TO_MUTATE_STMT(GT) -.DISPATCH_TO_MUTATE_STMT(GE) -.DISPATCH_TO_MUTATE_STMT(And) -.DISPATCH_TO_MUTATE_STMT(Or) -.DISPATCH_TO_MUTATE_STMT(Reduce) -.DISPATCH_TO_MUTATE_STMT(Cast) -.DISPATCH_TO_MUTATE_STMT(Not) -.DISPATCH_TO_MUTATE_STMT(Select) -.DISPATCH_TO_MUTATE_STMT(Ramp) -.DISPATCH_TO_MUTATE_STMT(Broadcast) .DISPATCH_TO_MUTATE_STMT(AssertStmt) .DISPATCH_TO_MUTATE_STMT(ProducerConsumer) .DISPATCH_TO_MUTATE_STMT(Provide) .DISPATCH_TO_MUTATE_STMT(Realize) .DISPATCH_TO_MUTATE_STMT(Block) -.DISPATCH_TO_MUTATE_STMT(Evaluate) -.DISPATCH_TO_MUTATE_STMT(IntImm) -.DISPATCH_TO_MUTATE_STMT(UIntImm) -.DISPATCH_TO_MUTATE_STMT(FloatImm) -.DISPATCH_TO_MUTATE_STMT(StringImm); +.DISPATCH_TO_MUTATE_STMT(Evaluate); // Mutate Expr @@ -450,19 +386,6 @@ Expr IRMutator::Mutate_(const Broadcast *op, const Expr& e) { return e; \ } -DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(LetStmt) -DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(AttrStmt) -DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(For) -DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(IfThenElse) -DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(Allocate) -DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(Store) -DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(Free) -DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(AssertStmt) -DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(ProducerConsumer) -DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(Provide) -DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(Realize) -DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(Block) -DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(Evaluate) DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(IntImm) DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(UIntImm) DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(FloatImm) @@ -470,15 +393,8 @@ DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(StringImm) TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr) .DISPATCH_TO_MUTATE_EXPR(Variable) -.DISPATCH_TO_MUTATE_EXPR(LetStmt) -.DISPATCH_TO_MUTATE_EXPR(AttrStmt) -.DISPATCH_TO_MUTATE_EXPR(IfThenElse) -.DISPATCH_TO_MUTATE_EXPR(For) -.DISPATCH_TO_MUTATE_EXPR(Allocate) .DISPATCH_TO_MUTATE_EXPR(Load) -.DISPATCH_TO_MUTATE_EXPR(Store) .DISPATCH_TO_MUTATE_EXPR(Let) -.DISPATCH_TO_MUTATE_EXPR(Free) .DISPATCH_TO_MUTATE_EXPR(Call) .DISPATCH_TO_MUTATE_EXPR(Add) .DISPATCH_TO_MUTATE_EXPR(Sub) @@ -501,12 +417,6 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr) .DISPATCH_TO_MUTATE_EXPR(Select) .DISPATCH_TO_MUTATE_EXPR(Ramp) .DISPATCH_TO_MUTATE_EXPR(Broadcast) -.DISPATCH_TO_MUTATE_EXPR(AssertStmt) -.DISPATCH_TO_MUTATE_EXPR(ProducerConsumer) -.DISPATCH_TO_MUTATE_EXPR(Provide) -.DISPATCH_TO_MUTATE_EXPR(Realize) -.DISPATCH_TO_MUTATE_EXPR(Block) -.DISPATCH_TO_MUTATE_EXPR(Evaluate) .DISPATCH_TO_MUTATE_EXPR(IntImm) .DISPATCH_TO_MUTATE_EXPR(UIntImm) .DISPATCH_TO_MUTATE_EXPR(FloatImm) diff --git a/tests/cpp/ir_functor_test.cc b/tests/cpp/ir_functor_test.cc index 8e0e68e3f5c2..95c6e017605b 100644 --- a/tests/cpp/ir_functor_test.cc +++ b/tests/cpp/ir_functor_test.cc @@ -2,10 +2,11 @@ #include #include #include +#include TEST(IRF, Basic) { - using namespace Halide::Internal; using namespace tvm; + using namespace tvm::ir; Var x("x"); auto z = x + 1; @@ -21,6 +22,65 @@ TEST(IRF, Basic) { CHECK_EQ(f(z, 2), 4); } +TEST(IRF, ExprTransform) { + using namespace tvm; + using namespace tvm::ir; + Var x("x"); + auto z = x + 1; + + class MyExprFunctor + : public ir::ExprFunctor { + public: + int VisitExpr_(const Variable* op, int b) final { + return b; + } + int VisitExpr_(const IntImm* op, int b) final { + return op->value; + } + int VisitExpr_(const Add* op, int b) final { + return VisitExpr(op->a, b) + VisitExpr(op->b, b); + } + }; + MyExprFunctor f; + CHECK_EQ(f(x, 2), 2); + CHECK_EQ(f(z, 2), 3); + try { + f(z - 1, 2); + LOG(FATAL) << "should fail"; + } catch(dmlc::Error) { + } +} + +TEST(IRF, ExprVisit) { + using namespace tvm; + using namespace tvm::ir; + Var x("x"); + auto z = x + 1; + + class MyVisitor + : public ir::ExprFunctor, + public ir::StmtFunctor { + public: + int count = 0; + // implementation + void VisitExpr_(const Variable* op) final { + ++count; + } + void VisitExpr_(const IntImm* op) final { + } + void VisitExpr_(const Add* op) final { + VisitExpr(op->a); + VisitExpr(op->b); + } + void VisitStmt_(const Evaluate* op) final { + VisitExpr(op->value); + } + }; + MyVisitor v; + v(Evaluate::make(z)); + CHECK_EQ(v.count, 1); +} + int main(int argc, char ** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; diff --git a/tests/python/unittest/test_arith_intset.py b/tests/python/unittest/test_arith_intset.py index b60ed0d510b4..b677ea6ec6fa 100644 --- a/tests/python/unittest/test_arith_intset.py +++ b/tests/python/unittest/test_arith_intset.py @@ -25,6 +25,11 @@ def test_deduce(): ans1 = (c-b)/4+(-2) assert str(tvm.ir_pass.Simplify(res1.max())) == str(ans1) + e2 = (tvm.max(5, a * 4) < 0) + res2 = tvm.arith.DeduceBound(a, e2, {b: b_s, c: c_s, d: d_s}) + assert str(res2.max()) == "neg_inf" + assert str(res2.min()) == "pos_inf" + def test_check(): a = tvm.Var('a') b = tvm.Var('b') diff --git a/tests/python/unittest/test_arith_modular.py b/tests/python/unittest/test_arith_modular.py new file mode 100644 index 000000000000..71261d06cffc --- /dev/null +++ b/tests/python/unittest/test_arith_modular.py @@ -0,0 +1,32 @@ +import tvm + +def test_basic(): + a = tvm.Var() + b = tvm.Var() + m = tvm.arith.EvalModular(a * 4 + b * 6 + 7) + assert m.coeff == 2 + assert m.base == 1 + + m = tvm.arith.EvalModular((a * 4 + 1) * (b * 8 + 3)) + assert m.coeff == 4 + assert m.base == 3 + + m = tvm.arith.EvalModular((a * 4 + 1) / (b * 8 + 3)) + assert m.coeff == 1 + assert m.base == 0 + + m = tvm.arith.EvalModular((a * 4 + 1) * (b * 8 / 4)) + assert m.coeff == 2 + assert m.base == 0 + + m = tvm.arith.EvalModular((a * 12 + 1) - (b * 3 * 7 + 2)) + assert m.coeff == 3 + assert m.base == 2 + + + m = tvm.arith.EvalModular(a * 12 + tvm.min(b * 3 * 7, 2)) + assert m.coeff == 1 + assert m.base == 0 + +if __name__ == "__main__": + test_basic() From e06b09cca1df23ba167b61a7e7276e699b81d96a Mon Sep 17 00:00:00 2001 From: tqchen Date: Tue, 28 Feb 2017 11:51:34 -0800 Subject: [PATCH 2/7] retrigger --- .travis.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index e7110ecbacac..3b967d8aac2c 100644 --- a/.travis.yml +++ b/.travis.yml @@ -59,7 +59,6 @@ after_failure: - tests/travis/travis_after_failure.sh notifications: -# Emails are sent to the committer's git-configured email address by default, email: on_success: change on_failure: always From 4e0941d832bcb0785ac965bf78a228785b12ae65 Mon Sep 17 00:00:00 2001 From: tqchen Date: Tue, 28 Feb 2017 15:01:22 -0800 Subject: [PATCH 3/7] [IRFunctor] Migrated CodegenC --- src/codegen/codegen_c.cc | 472 ++++++++++----------- src/codegen/codegen_c.h | 102 +++-- src/codegen/codegen_cuda.cc | 4 +- src/codegen/codegen_cuda.h | 2 +- tests/python/unittest/test_codegen_llvm.py | 2 +- 5 files changed, 285 insertions(+), 297 deletions(-) diff --git a/src/codegen/codegen_c.cc b/src/codegen/codegen_c.cc index 4dbc8efef4f5..00cd4227c7c5 100644 --- a/src/codegen/codegen_c.cc +++ b/src/codegen/codegen_c.cc @@ -67,10 +67,6 @@ std::string CodeGenC::Finish() { return stream.str(); } -void CodeGenC::PrintStmt(const Stmt& n) { - static const FPrintStmt& f = vtable_print_stmt(); - f(n, this); -} std::string CodeGenC::SSAGetID(std::string src, Type t) { if (name_alloc_map_.count(src)) return src; @@ -96,13 +92,12 @@ std::string CodeGenC::SSAGetID(std::string src, Type t) { } void CodeGenC::PrintExpr(const Expr& n, std::ostream& os) { // NOLINT(*) - static const FPrintExpr& f = vtable_print_expr(); if (print_ssa_form_) { std::ostringstream temp; - f(n, temp, this); + VisitExpr(n, temp); os << SSAGetID(temp.str(), n.type()); } else { - f(n, os, this); + VisitExpr(n, os); } } @@ -178,6 +173,102 @@ void CodeGenC::MarkConst(std::string vid) { } } +int CodeGenC::BeginScope() { + int sid = static_cast(scope_mark_.size()); + scope_mark_.push_back(true); + indent += 2; + return sid; +} + +void CodeGenC::EndScope(int scope_id) { + scope_mark_[scope_id] = false; + indent -= 2; +} + +// Print a reference expression to a buffer. +void CodeGenC::PrintBufferRef( + const Variable* buffer, + Type t, Expr index, + std::ostream& os) { // NOLINT(*) + std::string vid = GetVarID(buffer); + if (t.lanes() == 1) { + if (!HandleTypeMatch(buffer, t)) { + os << "(("; + PrintType(t, os); + os << "*)" << vid << ')'; + } else { + os << vid; + } + os << '['; + PrintExpr(index, os); + os << ']'; + } else { + // Buffer declared as vector type. + // optimize for case where it is in register, + if (HandleTypeMatch(buffer, t)) { + // optimize for constant access + int offset; + if (arith::GetConstInt(index, &offset)) { + CHECK_EQ(offset % t.lanes(), 0) + << "Find unaligned vector load to a vector type"; + os << vid << '[' << (offset / t.lanes()) << ']'; + return; + } + } + os << "(("; + PrintType(t, os); + os << "*)("; + if (!HandleTypeMatch(buffer, t.element_of())) { + os << '('; + PrintType(t.element_of(), os); + os << "*)"; + } + os << vid << " + "; + PrintExpr(index, os); + os << "))[0]"; + } +} + +void CodeGenC::PrintVecElemLoad(const std::string& vec, + Type t, int i, + std::ostream& os) { // NOLINT(*) + os << vec << ".s" << std::hex << i; +} + +void CodeGenC::PrintVecElemStore(const std::string& vec, + Type t, int i, + const std::string& value) { + this->PrintIndent(); + stream << vec << ".s" << std::hex << i + << " = " << value << ";\n"; +} + +void CodeGenC::PrintVecLoad(const Variable* buffer, + Type t, Expr base, + std::ostream& os) { + PrintBufferRef(buffer, t, base, os); +} + +void CodeGenC::PrintVecStore(const Variable* buffer, + Type t, Expr base, + const std::string& value) { + this->PrintIndent(); + PrintBufferRef(buffer, t, base, stream); + stream << " = " << value << ";\n"; +} + +void CodeGenC::PrintThreadIndexExpr( + std::string thread_tag, std::ostream& os) { // NOLINT(*) + os << thread_tag; +} + +void CodeGenC::PrintStorageSync(const std::string& sync) { // NOLINT(*) +} + +void CodeGenC::PrintStorageScope(const std::string& scope, std::ostream& os) { // NOLINT(*) + CHECK_EQ(scope, "global"); +} + void CodeGenC::PrintType(Type t, std::ostream& os) const { // NOLINT(*) CHECK_EQ(t.lanes(), 1) << "do not yet support vector types"; @@ -208,13 +299,6 @@ void CodeGenC::PrintType(Type t, std::ostream& os) const { // NOLINT(*) LOG(FATAL) << "Cannot convert type " << t << " to C type"; } -CodeGenC::FPrintStmt& CodeGenC::vtable_print_stmt() { // NOLINT(*) - static FPrintStmt inst; return inst; -} - -CodeGenC::FPrintExpr& CodeGenC::vtable_print_expr() { // NOLINT(*) - static FPrintExpr inst; return inst; -} inline void PrintConst(const IntImm* op, std::ostream& os, CodeGenC* p) { // NOLINT(*) if (op->type == Int(32)) { @@ -262,19 +346,18 @@ inline void PrintConst(const FloatImm* op, std::ostream& os, CodeGenC* p) { // N } } -TVM_STATIC_IR_FUNCTOR(CodeGenC, vtable_print_expr) -.set_dispatch([](const IntImm *op, std::ostream& os, CodeGenC *p) { // NOLINT(*) - PrintConst(op, os, p); - }) -.set_dispatch([](const UIntImm *op, std::ostream& os, CodeGenC *p) { // NOLINT(*) - PrintConst(op, os, p); - }) -.set_dispatch([](const FloatImm *op, std::ostream& os, CodeGenC *p) { // NOLINT(*) - PrintConst(op, os, p); - }) -.set_dispatch([](const StringImm *op, std::ostream& os, CodeGenC *p) { // NOLINT(*) - os << "\"" << op->value << "\""; - }); +void CodeGenC::VisitExpr_(const IntImm *op, std::ostream& os) { // NOLINT(*) + PrintConst(op, os, this); +} +void CodeGenC::VisitExpr_(const UIntImm *op, std::ostream& os) { // NOLINT(*) + PrintConst(op, os, this); +} +void CodeGenC::VisitExpr_(const FloatImm *op, std::ostream& os) { // NOLINT(*) + PrintConst(op, os, this); +} +void CodeGenC::VisitExpr_(const StringImm *op, std::ostream& os) { // NOLINT(*) + os << "\"" << op->value << "\""; +} template inline void PrintBinaryExpr(const T* op, @@ -315,137 +398,99 @@ inline void PrintBinaryIntrinsitc(const Call* op, p->PrintVecBinaryOp(opstr, op->type, op->args[0], op->args[1], os); } } +void CodeGenC::VisitExpr_(const Cast *op, std::ostream& os) { // NOLINT(*) + this->PrintType(op->type, os); + os << '('; + this->PrintExpr(op->value, os); + os << ')'; +} +void CodeGenC::VisitExpr_(const Variable *op, std::ostream& os) { // NOLINT(*) + os << GetVarID(op); +} +void CodeGenC::VisitExpr_(const Add *op, std::ostream& os) { // NOLINT(*) + PrintBinaryExpr(op, "+", os, this); +} +void CodeGenC::VisitExpr_(const Sub *op, std::ostream& os) { // NOLINT(*) + PrintBinaryExpr(op, "-", os, this); +} +void CodeGenC::VisitExpr_(const Mul *op, std::ostream& os) { // NOLINT(*) + PrintBinaryExpr(op, "*", os, this); +} +void CodeGenC::VisitExpr_(const Div *op, std::ostream& os) { // NOLINT(*) + PrintBinaryExpr(op, "/", os, this); +} +void CodeGenC::VisitExpr_(const Mod *op, std::ostream& os) { // NOLINT(*) + PrintBinaryExpr(op, "%", os, this); +} +void CodeGenC::VisitExpr_(const Min *op, std::ostream& os) { // NOLINT(*) + PrintBinaryExpr(op, "min", os, this); +} +void CodeGenC::VisitExpr_(const Max *op, std::ostream& os) { // NOLINT(*) + PrintBinaryExpr(op, "max", os, this); +} +void CodeGenC::VisitExpr_(const EQ *op, std::ostream& os) { // NOLINT(*) + PrintBinaryExpr(op, "==", os, this); +} +void CodeGenC::VisitExpr_(const NE *op, std::ostream& os) { // NOLINT(*) + PrintBinaryExpr(op, "!=", os, this); +} +void CodeGenC::VisitExpr_(const LT *op, std::ostream& os) { // NOLINT(*) + PrintBinaryExpr(op, "<", os, this); +} +void CodeGenC::VisitExpr_(const LE *op, std::ostream& os) { // NOLINT(*) + PrintBinaryExpr(op, "<=", os, this); +} +void CodeGenC::VisitExpr_(const GT *op, std::ostream& os) { // NOLINT(*) + PrintBinaryExpr(op, ">", os, this); +} +void CodeGenC::VisitExpr_(const GE *op, std::ostream& os) { // NOLINT(*) + PrintBinaryExpr(op, ">=", os, this); +} +void CodeGenC::VisitExpr_(const And *op, std::ostream& os) { // NOLINT(*) + PrintBinaryExpr(op, "&&", os, this); +} +void CodeGenC::VisitExpr_(const Or *op, std::ostream& os) { // NOLINT(*) + PrintBinaryExpr(op, "||", os, this); +} +void CodeGenC::VisitExpr_(const Not *op, std::ostream& os) { // NOLINT(*) + os << '!'; + PrintExpr(op->a, os); +} -TVM_STATIC_IR_FUNCTOR(CodeGenC, vtable_print_expr) -.set_dispatch([](const Cast *op, std::ostream& os, CodeGenC *p) { // NOLINT(*) - p->PrintType(op->type, os); - os << '('; - p->PrintExpr(op->value, os); - os << ')'; - }) -.set_dispatch([](const Variable *op, std::ostream& os, CodeGenC* p) { // NOLINT(*) - os << p->GetVarID(op); - }) -.set_dispatch([](const Add *op, std::ostream& os, CodeGenC* p) { // NOLINT(*) - PrintBinaryExpr(op, "+", os, p); - }) -.set_dispatch([](const Sub *op, std::ostream& os, CodeGenC* p) { // NOLINT(*) - PrintBinaryExpr(op, "-", os, p); - }) -.set_dispatch([](const Mul *op, std::ostream& os, CodeGenC* p) { // NOLINT(*) - PrintBinaryExpr(op, "*", os, p); - }) -.set_dispatch
([](const Div *op, std::ostream& os, CodeGenC* p) { // NOLINT(*) - PrintBinaryExpr(op, "/", os, p); - }) -.set_dispatch([](const Mod *op, std::ostream& os, CodeGenC* p) { // NOLINT(*) - PrintBinaryExpr(op, "%", os, p); -}) -.set_dispatch([](const Min *op, std::ostream& os, CodeGenC* p) { // NOLINT(*) - PrintBinaryExpr(op, "min", os, p); -}) -.set_dispatch([](const Max *op, std::ostream& os, CodeGenC* p) { // NOLINT(*) - PrintBinaryExpr(op, "max", os, p); -}) -.set_dispatch([](const EQ *op, std::ostream& os, CodeGenC* p) { // NOLINT(*) - PrintBinaryExpr(op, "==", os, p); -}) -.set_dispatch([](const NE *op, std::ostream& os, CodeGenC* p) { // NOLINT(*) - PrintBinaryExpr(op, "!=", os, p); -}) -.set_dispatch([](const LT *op, std::ostream& os, CodeGenC* p) { // NOLINT(*) - PrintBinaryExpr(op, "<", os, p); -}) -.set_dispatch([](const LE *op, std::ostream& os, CodeGenC* p) { // NOLINT(*) - PrintBinaryExpr(op, "<=", os, p); -}) -.set_dispatch([](const GT *op, std::ostream& os, CodeGenC* p) { // NOLINT(*) - PrintBinaryExpr(op, ">", os, p); -}) -.set_dispatch([](const GE *op, std::ostream& os, CodeGenC* p) { // NOLINT(*) - PrintBinaryExpr(op, ">=", os, p); -}) -.set_dispatch([](const And *op, std::ostream& os, CodeGenC* p) { // NOLINT(*) - PrintBinaryExpr(op, "&&", os, p); -}) -.set_dispatch([](const Or *op, std::ostream& os, CodeGenC* p) { // NOLINT(*) - PrintBinaryExpr(op, "||", os, p); -}) -.set_dispatch([](const Not *op, std::ostream& os, CodeGenC* p) { // NOLINT(*) - os << '!'; - p->PrintExpr(op->a, os); - }); - -TVM_STATIC_IR_FUNCTOR(CodeGenC, vtable_print_stmt) -.set_dispatch([](const ProducerConsumer *op, CodeGenC* p) { - p->PrintStmt(op->body); - }) -.set_dispatch([](const Block *op, CodeGenC* p) { - p->PrintStmt(op->first); - if (op->rest.defined()) p->PrintStmt(op->rest); - }) -.set_dispatch([](const Evaluate *op, CodeGenC* p) { - if (is_const(op->value)) return; - const Call* call = op->value.as(); - - if (call && call->is_intrinsic(intrinsic::tvm_storage_sync)) { - p->PrintStorageSync(call->args[0].as()->value); - } else { - std::string vid = p->PrintExpr(op->value); - p->PrintIndent(); - p->stream << "(void)" << vid << ";\n"; - } - }); - - -#define DISPATCH_EXPR(OP) \ - set_dispatch([](const OP *op, std::ostream&os, CodeGenC* p) { \ - p->PrintExpr(op, os); }) - -TVM_STATIC_IR_FUNCTOR(CodeGenC, vtable_print_expr) -.DISPATCH_EXPR(Load) -.DISPATCH_EXPR(Call) -.DISPATCH_EXPR(Let) -.DISPATCH_EXPR(Ramp) -.DISPATCH_EXPR(Broadcast) -.DISPATCH_EXPR(Select); - - -void CodeGenC::PrintExpr(const Call *op, std::ostream& os) { // NOLINT(*) - CodeGenC* p = this; +void CodeGenC::VisitExpr_(const Call *op, std::ostream& os) { // NOLINT(*) if (op->is_intrinsic(Call::bitwise_and)) { - PrintBinaryIntrinsitc(op, " & ", os, p); + PrintBinaryIntrinsitc(op, " & ", os, this); } else if (op->is_intrinsic(Call::bitwise_xor)) { - PrintBinaryIntrinsitc(op, " ^ ", os, p); + PrintBinaryIntrinsitc(op, " ^ ", os, this); } else if (op->is_intrinsic(Call::bitwise_or)) { - PrintBinaryIntrinsitc(op, " | ", os, p); + PrintBinaryIntrinsitc(op, " | ", os, this); } else if (op->is_intrinsic(Call::bitwise_not)) { CHECK_EQ(op->args.size(), 1U); os << "(~"; - p->PrintExpr(op->args[0], os); + this->PrintExpr(op->args[0], os); os << ')'; } else if (op->is_intrinsic(Call::shift_left)) { - PrintBinaryIntrinsitc(op, " << ", os, p); + PrintBinaryIntrinsitc(op, " << ", os, this); } else if (op->is_intrinsic(Call::shift_right)) { - PrintBinaryIntrinsitc(op, " >> ", os, p); + PrintBinaryIntrinsitc(op, " >> ", os, this); } else if (op->is_intrinsic(Call::address_of)) { const Load *l = op->args[0].as(); CHECK(op->args.size() == 1 && l); os << "(("; - p->PrintType(l->type.element_of(), os); - os << " *)" << p->GetVarID(l->buffer_var.get()) + this->PrintType(l->type.element_of(), os); + os << " *)" << this->GetVarID(l->buffer_var.get()) << " + "; - p->PrintExpr(l->index, os); + this->PrintExpr(l->index, os); os << ')'; } else if (op->is_intrinsic(intrinsic::tvm_api_load_arg)) { CHECK_EQ(op->args.size(), 3U); if (!op->type.is_handle()) { os << '('; - p->PrintType(op->type, os); + this->PrintType(op->type, os); os << ')'; } os << "(((TVMArg*)"; - p->PrintExpr(op->args[0], os); + this->PrintExpr(op->args[0], os); os << ")[" << op->args[2] << "]."; if (op->type.is_handle()) { os << "v_handle"; @@ -460,7 +505,7 @@ void CodeGenC::PrintExpr(const Call *op, std::ostream& os) { // NOLINT(*) } else if (op->is_intrinsic(intrinsic::tvm_array_get_field)) { CHECK_EQ(op->args.size(), 2U); os << "(((TVMArray*)"; - p->PrintExpr(op->args[0], os); + this->PrintExpr(op->args[0], os); os << ")->"; switch (op->args[1].as()->value) { case intrinsic::kData: os << "data"; break; @@ -476,12 +521,12 @@ void CodeGenC::PrintExpr(const Call *op, std::ostream& os) { // NOLINT(*) } else if (op->is_intrinsic(intrinsic::tvm_handle_is_null)) { CHECK_EQ(op->args.size(), 1U); os << "("; - p->PrintExpr(op->args[0], os); + this->PrintExpr(op->args[0], os); os << " == NULL)"; } else { os << op->name << "("; for (size_t i = 0; i < op->args.size(); i++) { - p->PrintExpr(op->args[i], os); + this->PrintExpr(op->args[i], os); if (i < op->args.size() - 1) { os << ", "; } @@ -517,51 +562,7 @@ inline bool TryGetRamp1Base(Expr index, int lanes, Expr *base) { return true; } -// Print a reference expression to a buffer. -void CodeGenC::PrintBufferRef( - const Variable* buffer, - Type t, Expr index, - std::ostream& os) { // NOLINT(*) - std::string vid = GetVarID(buffer); - if (t.lanes() == 1) { - if (!HandleTypeMatch(buffer, t)) { - os << "(("; - PrintType(t, os); - os << "*)" << vid << ')'; - } else { - os << vid; - } - os << '['; - PrintExpr(index, os); - os << ']'; - } else { - // Buffer declared as vector type. - // optimize for case where it is in register, - if (HandleTypeMatch(buffer, t)) { - // optimize for constant access - int offset; - if (arith::GetConstInt(index, &offset)) { - CHECK_EQ(offset % t.lanes(), 0) - << "Find unaligned vector load to a vector type"; - os << vid << '[' << (offset / t.lanes()) << ']'; - return; - } - } - os << "(("; - PrintType(t, os); - os << "*)("; - if (!HandleTypeMatch(buffer, t.element_of())) { - os << '('; - PrintType(t.element_of(), os); - os << "*)"; - } - os << vid << " + "; - PrintExpr(index, os); - os << "))[0]"; - } -} - -void CodeGenC::PrintExpr(const Load* op, std::ostream& os) { // NOLINT(*) +void CodeGenC::VisitExpr_(const Load* op, std::ostream& os) { // NOLINT(*) int lanes = op->type.lanes(); if (op->type.lanes() == 1) { this->PrintBufferRef(op->buffer_var.get(), op->type, op->index, os); @@ -600,7 +601,7 @@ void CodeGenC::PrintExpr(const Load* op, std::ostream& os) { // NOLINT(*) } } -void CodeGenC::PrintStmt(const Store* op) { +void CodeGenC::VisitStmt_(const Store* op) { Type t = op->value.type(); if (t.lanes() == 1) { this->PrintIndent(); @@ -637,35 +638,7 @@ void CodeGenC::PrintStmt(const Store* op) { } } -void CodeGenC::PrintVecElemLoad(const std::string& vec, - Type t, int i, - std::ostream& os) { // NOLINT(*) - os << vec << ".s" << std::hex << i; -} - -void CodeGenC::PrintVecElemStore(const std::string& vec, - Type t, int i, - const std::string& value) { - this->PrintIndent(); - stream << vec << ".s" << std::hex << i - << " = " << value << ";\n"; -} - -void CodeGenC::PrintVecLoad(const Variable* buffer, - Type t, Expr base, - std::ostream& os) { - PrintBufferRef(buffer, t, base, os); -} - -void CodeGenC::PrintVecStore(const Variable* buffer, - Type t, Expr base, - const std::string& value) { - this->PrintIndent(); - PrintBufferRef(buffer, t, base, stream); - stream << " = " << value << ";\n"; -} - -void CodeGenC::PrintExpr(const Let* op, std::ostream& os) { // NOLINT(*) +void CodeGenC::VisitExpr_(const Let* op, std::ostream& os) { // NOLINT(*) CHECK(print_ssa_form_) << "LetExpr is only supported by print SSA form"; std::string value = PrintExpr(op->value); @@ -673,41 +646,19 @@ void CodeGenC::PrintExpr(const Let* op, std::ostream& os) { // NOLINT(*) var_idmap_[op->var.get()] = value; } -void CodeGenC::PrintExpr(const Ramp* op, std::ostream& os) { // NOLINT(*) +void CodeGenC::VisitExpr_(const Ramp* op, std::ostream& os) { // NOLINT(*) LOG(FATAL) << "Ramp: not supported "; } -void CodeGenC::PrintExpr(const Broadcast* op, std::ostream& os) { // NOLINT(*) +void CodeGenC::VisitExpr_(const Broadcast* op, std::ostream& os) { // NOLINT(*) LOG(FATAL) << "Broadcast: not supported "; } -void CodeGenC::PrintExpr(const Select* op, std::ostream& os) { // NOLINT(*) +void CodeGenC::VisitExpr_(const Select* op, std::ostream& os) { // NOLINT(*) LOG(FATAL) << "Select: not supported "; } -// Disoatch back to member functions -TVM_STATIC_IR_FUNCTOR(CodeGenC, vtable_print_stmt) -.set_dispatch([](const LetStmt *op, CodeGenC* p) { p->PrintStmt(op); }) -.set_dispatch([](const Store *op, CodeGenC* p) { p->PrintStmt(op); }) -.set_dispatch([](const Allocate *op, CodeGenC* p) { p->PrintStmt(op); }) -.set_dispatch([](const AttrStmt *op, CodeGenC* p) { p->PrintStmt(op); }) -.set_dispatch([](const AssertStmt *op, CodeGenC* p) { p->PrintStmt(op); }) -.set_dispatch([](const For *op, CodeGenC* p) { p->PrintStmt(op); }) -.set_dispatch([](const IfThenElse *op, CodeGenC* p) { p->PrintStmt(op); }); - -void CodeGenC::PrintThreadIndexExpr( - std::string thread_tag, std::ostream& os) { // NOLINT(*) - os << thread_tag; -} - -void CodeGenC::PrintStorageSync(const std::string& sync) { // NOLINT(*) -} - -void CodeGenC::PrintStorageScope(const std::string& scope, std::ostream& os) { // NOLINT(*) - CHECK_EQ(scope, "global"); -} - -void CodeGenC::PrintStmt(const LetStmt* op) { +void CodeGenC::VisitStmt_(const LetStmt* op) { std::string value = PrintExpr(op->value); if (print_ssa_form_) { CHECK(!var_idmap_.count(op->var.get())); @@ -732,7 +683,7 @@ void CodeGenC::PrintStmt(const LetStmt* op) { PrintStmt(op->body); } -void CodeGenC::PrintStmt(const Allocate* op) { +void CodeGenC::VisitStmt_(const Allocate* op) { CHECK(!is_zero(op->condition)); std::string vid = AllocVarID(op->buffer_var.get()); if (op->new_expr.defined()) { @@ -758,7 +709,7 @@ void CodeGenC::PrintStmt(const Allocate* op) { this->PrintStmt(op->body); } -void CodeGenC::PrintStmt(const AttrStmt* op) { +void CodeGenC::VisitStmt_(const AttrStmt* op) { if (op->type_key == ir::attr::thread_extent) { IterVar iv(op->node.node_); if (iv->thread_tag.length() != 0) { @@ -780,7 +731,7 @@ void CodeGenC::PrintStmt(const AttrStmt* op) { this->PrintStmt(op->body); } -void CodeGenC::PrintStmt(const AssertStmt* op) { +void CodeGenC::VisitStmt_(const AssertStmt* op) { std::string cond = PrintExpr(op->condition); PrintIndent(); if (op->message.as()) { @@ -792,19 +743,7 @@ void CodeGenC::PrintStmt(const AssertStmt* op) { } } -int CodeGenC::BeginScope() { - int sid = static_cast(scope_mark_.size()); - scope_mark_.push_back(true); - indent += 2; - return sid; -} - -void CodeGenC::EndScope(int scope_id) { - scope_mark_[scope_id] = false; - indent -= 2; -} - -void CodeGenC::PrintStmt(const For* op) { +void CodeGenC::VisitStmt_(const For* op) { std::string extent = PrintExpr(op->extent); PrintIndent(); std::string vid = AllocVarID(op->loop_var.get()); @@ -821,7 +760,7 @@ void CodeGenC::PrintStmt(const For* op) { stream << "}\n"; } -void CodeGenC::PrintStmt(const IfThenElse* op) { +void CodeGenC::VisitStmt_(const IfThenElse* op) { std::string cond = PrintExpr(op->condition); PrintIndent(); stream << "if (" << cond << ") {\n"; @@ -840,6 +779,27 @@ void CodeGenC::PrintStmt(const IfThenElse* op) { stream << "}\n"; } +void CodeGenC::VisitStmt_(const Block *op) { + PrintStmt(op->first); + if (op->rest.defined()) PrintStmt(op->rest); +} + +void CodeGenC::VisitStmt_(const Evaluate *op) { + if (is_const(op->value)) return; + const Call* call = op->value.as(); + + if (call && call->is_intrinsic(intrinsic::tvm_storage_sync)) { + this->PrintStorageSync(call->args[0].as()->value); + } else { + std::string vid = this->PrintExpr(op->value); + this->PrintIndent(); + this->stream << "(void)" << vid << ";\n"; + } +} + +void CodeGenC::VisitStmt_(const ProducerConsumer *op) { + PrintStmt(op->body); +} } // namespace codegen } // namespace tvm diff --git a/src/codegen/codegen_c.h b/src/codegen/codegen_c.h index 1c87c9164034..c2fa60423d17 100644 --- a/src/codegen/codegen_c.h +++ b/src/codegen/codegen_c.h @@ -7,6 +7,7 @@ #define TVM_CODEGEN_CODEGEN_C_H_ #include +#include #include #include #include @@ -16,12 +17,15 @@ namespace tvm { namespace codegen { +using namespace ir; /*! * \brief A base class to generate C code. * * CodeGenC have two modes: generate SSA formed C code or normal form. */ -class CodeGenC { +class CodeGenC : + public StmtFunctor, + public ExprFunctor { public: /*! * \brief Initialize the code generator. @@ -42,13 +46,15 @@ class CodeGenC { * \brief Print the Stmt n to CodeGenC->stream * \param n The statement to be printed. */ - void PrintStmt(const Stmt& n); + void PrintStmt(const Stmt& n) { + VisitStmt(n); + } /*! * \brief Print the expression n(or its ssa id if in ssa mode) into os * \param n The expression to be printed. * \param os The output stream */ - void PrintExpr(const Expr& n, std::ostream& os); // NOLINT(*) + void PrintExpr(const Expr& n, std::ostream& os); /*! * \brief Same as PrintExpr, but simply returns result string * \param n The expression to be printed. @@ -84,6 +90,46 @@ class CodeGenC { * \param f The function to be compiled. */ virtual void InitFuncState(LoweredFunc f); + // expression + void VisitExpr_(const Variable* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const Load* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const Let* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const Call* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const Add* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const Sub* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const Mul* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const Div* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const Mod* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const Min* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const Max* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const EQ* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const NE* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const LT* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const LE* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const GT* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const GE* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const And* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const Or* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const Cast* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const Not* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const Select* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const Ramp* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const Broadcast* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const IntImm* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const UIntImm* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const FloatImm* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const StringImm* op, std::ostream& os) override; // NOLINT(*) + // statment + void VisitStmt_(const LetStmt* op) override; + void VisitStmt_(const Store* op) override; + void VisitStmt_(const For* op) override; + void VisitStmt_(const IfThenElse* op) override; + void VisitStmt_(const Allocate* op) override; + void VisitStmt_(const AttrStmt* op) override; + void VisitStmt_(const AssertStmt* op) override; + void VisitStmt_(const Evaluate* op) override; + void VisitStmt_(const Block* op) override; + void VisitStmt_(const ProducerConsumer* op) override; /*! * Print Type represetnation of type t. * \param t The type representation. @@ -97,50 +143,37 @@ class CodeGenC { */ virtual void PrintThreadIndexExpr( std::string tag, std::ostream& os); // NOLINT(*) - virtual void PrintStorageScope(const std::string& scope, std::ostream& os); // NOLINT(* + virtual void PrintStorageScope(const std::string& scope, std::ostream& os); // NOLINT(*) virtual void PrintStorageSync(const std::string& scope); // NOLINT(*) - - virtual void PrintStmt(const ir::LetStmt* op); - virtual void PrintStmt(const ir::Store* op); - virtual void PrintStmt(const ir::For* op); - virtual void PrintStmt(const ir::IfThenElse* op); - virtual void PrintStmt(const ir::Allocate* op); - virtual void PrintStmt(const ir::AttrStmt* op); - virtual void PrintStmt(const ir::AssertStmt* op); - virtual void PrintExpr(const ir::Load* op, std::ostream& os); // NOLINT(*) - virtual void PrintExpr(const ir::Call* op, std::ostream& os); // NOLINT(*) - virtual void PrintExpr(const ir::Let* op, std::ostream& os); // NOLINT(*) - virtual void PrintExpr(const ir::Ramp* op, std::ostream& os); // NOLINT(*) - virtual void PrintExpr(const ir::Broadcast* op, std::ostream& os); // NOLINT(*) - virtual void PrintExpr(const ir::Select* op, std::ostream& os); // NOLINT(*) // Binary vector op. virtual void PrintVecBinaryOp( const std::string&op, Type op_type, Expr lhs, Expr rhs, std::ostream& os); // NOLINT(*) + // print vector load virtual void PrintVecLoad(const Variable* buffer, Type t, Expr base, std::ostream& os); // NOLINT(*) + // print vector store virtual void PrintVecStore(const Variable* buffer, Type t, Expr base, const std::string& value); // NOLINT(*) + // print load of single element virtual void PrintVecElemLoad( const std::string& vec, Type t, int i, std::ostream& os); // NOLINT(*) + // print store of single element. virtual void PrintVecElemStore( const std::string& vec, Type t, int i, const std::string& value); - /*! \brief function print into the ostream */ - using FPrintExpr = IRFunctor; // NOLINT(*) - /*! \brief function to to print normal code */ - using FPrintStmt = IRFunctor; - // vtable to print code - static FPrintStmt& vtable_print_stmt(); - // vtable to print code - static FPrintExpr& vtable_print_expr(); - /*! \brief The current indentation value */ - int indent{0}; - /*! \brief the stream to be printed */ - std::ostringstream stream; protected: + /*! \brief the stream to be printed */ + std::ostringstream stream; + /*! \brief entry in ssa assign map */ + struct SSAEntry { + /*! \brief The value id */ + std::string vid; + /*! \brief The scope id */ + int scope_id; + }; // print reference to a buffer as type t in index. void PrintBufferRef(const Variable* buffer, Type t, Expr index, @@ -158,13 +191,6 @@ class CodeGenC { * \return The returned name. */ std::string GetUniqueName(std::string prefix); - /*! \brief entry in ssa assign map */ - struct SSAEntry { - /*! \brief The value id */ - std::string vid; - /*! \brief The scope id */ - int scope_id; - }; /*! * \brief mark the beginning of a new scope * \return The scope id. @@ -209,6 +235,8 @@ class CodeGenC { std::unordered_map handle_data_type_; /*! \brief array to check whether we are inside certain scope */ std::vector scope_mark_; + /*! \brief The current indentation value */ + int indent{0}; }; } // namespace codegen diff --git a/src/codegen/codegen_cuda.cc b/src/codegen/codegen_cuda.cc index 8112ee7b1d8c..0c12967b93e0 100644 --- a/src/codegen/codegen_cuda.cc +++ b/src/codegen/codegen_cuda.cc @@ -19,7 +19,7 @@ void CodeGenCUDA::AddFunction(LoweredFunc f) { CodeGenC::AddFunction(f); } -void CodeGenCUDA::PrintStmt(const ir::For* op) { +void CodeGenCUDA::VisitStmt_(const ir::For* op) { int ext; CHECK(is_zero(op->min)); if (arith::GetConstInt(op->extent, &ext) && @@ -27,7 +27,7 @@ void CodeGenCUDA::PrintStmt(const ir::For* op) { PrintIndent(); stream << "#pragma unroll\n"; } - CodeGenC::PrintStmt(op); + CodeGenC::VisitStmt_(op); } void CodeGenCUDA::PrintType(Type t, std::ostream& os) const { // NOLINT(*) diff --git a/src/codegen/codegen_cuda.h b/src/codegen/codegen_cuda.h index 4519e779fa8b..d6556ba2c106 100644 --- a/src/codegen/codegen_cuda.h +++ b/src/codegen/codegen_cuda.h @@ -18,7 +18,7 @@ class CodeGenCUDA : public CodeGenC { public: void AddFunction(LoweredFunc f); // override behavior - void PrintStmt(const ir::For* op) final; + void VisitStmt_(const ir::For* op) final; void PrintStorageSync(const std::string& sync) final; void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*) void PrintVecBinaryOp( diff --git a/tests/python/unittest/test_codegen_llvm.py b/tests/python/unittest/test_codegen_llvm.py index fed6cb6f283c..bf2a02826055 100644 --- a/tests/python/unittest/test_codegen_llvm.py +++ b/tests/python/unittest/test_codegen_llvm.py @@ -16,7 +16,7 @@ def check_llvm(): f = tvm.build(s, [A, B, C], "llvm") ctx = tvm.cpu(0) # launch the kernel. - n = 10270 * 2460 + n = 1027 * 1024 a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx) b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), ctx) c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx) From 91267947105fc1919ecdd1345de6d81969da691a Mon Sep 17 00:00:00 2001 From: tqchen Date: Tue, 28 Feb 2017 15:13:40 -0800 Subject: [PATCH 4/7] [IRFUNCTOR] Migrate CodeGenLLVM --- src/codegen/codegen_c.h | 4 +- src/codegen/llvm/codegen_llvm.cc | 1229 +++++++++++++++--------------- src/codegen/llvm/codegen_llvm.h | 86 +-- 3 files changed, 667 insertions(+), 652 deletions(-) diff --git a/src/codegen/codegen_c.h b/src/codegen/codegen_c.h index c2fa60423d17..9068d0223b8e 100644 --- a/src/codegen/codegen_c.h +++ b/src/codegen/codegen_c.h @@ -24,8 +24,8 @@ using namespace ir; * CodeGenC have two modes: generate SSA formed C code or normal form. */ class CodeGenC : - public StmtFunctor, - public ExprFunctor { + public ExprFunctor, + public StmtFunctor { public: /*! * \brief Initialize the code generator. diff --git a/src/codegen/llvm/codegen_llvm.cc b/src/codegen/llvm/codegen_llvm.cc index fbe455c691ea..d7c0a35d3b79 100644 --- a/src/codegen/llvm/codegen_llvm.cc +++ b/src/codegen/llvm/codegen_llvm.cc @@ -130,7 +130,7 @@ void CodeGenLLVM::AddFunction(const LoweredFunc& f) { llvm::BasicBlock* block = llvm::BasicBlock::Create(*ctx_, "entry", function_); builder_->SetInsertPoint(block); - this->Visit(f->body); + this->VisitStmt(f->body); builder_->CreateRet(ConstInt32(0)); } @@ -222,240 +222,369 @@ llvm::Type* CodeGenLLVM::LLVMType(const Type& t) const { return ret; } -void CodeGenLLVM::Visit_(const Variable* op) { - value_ = GetVarValue(op); +llvm::BasicBlock* CodeGenLLVM::CheckCallSuccess(llvm::Value* retcode) { + // create emit codes that checks and load the function. + using llvm::BasicBlock; + BasicBlock* fail_block = BasicBlock::Create( + *ctx_, "call_fail", function_); + BasicBlock* end_block = BasicBlock::Create( + *ctx_, "call_end", function_); + llvm::Value* succ = builder_->CreateICmpEQ( + retcode, llvm::ConstantInt::get(t_int_, 0)); + builder_->CreateCondBr(succ, end_block, fail_block, md_very_likely_branch_); + builder_->SetInsertPoint(fail_block); + // return the code. + builder_->CreateRet(retcode); + // otherwise set it to be new end. + builder_->SetInsertPoint(end_block); + return end_block; } -void CodeGenLLVM::Visit_(const Cast* op) { - value_ = CreateCast(op->value.type(), op->type, MakeValue(op->value)); -} +void CodeGenLLVM::AddAliasInfo( + llvm::Instruction* inst, const Variable* buffer, Expr index) { + int base = 0, width = 0; + // create meta-data for alias analysis + // Use a group of binary tree ranges. + const Ramp* ramp = index.as(); + if (ramp) { + int base, stride; + if (arith::GetConstInt(ramp->base, &base) && + arith::GetConstInt(ramp->stride, &stride)) { + int xwith = ramp->lanes * stride; + width = 1; + while (width < xwith) { + width *= 2; + } + while (base % width) { + base -= base % width; + width *= 2; + } + } + } else { + if (arith::GetConstInt(index, &base)) width = 1; + } -void CodeGenLLVM::Visit_(const IntImm* op) { - value_ = llvm::ConstantInt::getSigned(LLVMType(op->type), op->value); + llvm::MDNode* meta = md_tbaa_root_; + std::ostringstream buffer_addr; + buffer_addr << buffer; + meta = md_builder_->createTBAAScalarTypeNode(buffer_addr.str(), meta); + // create a tree-shape access structure. + if (width != 0) { + for (int w = 1024; w >= width; w /= 2) { + int b = (base / w) * w; + std::stringstream os; + os << buffer << ".w" << w << ".b" << b; + meta = md_builder_->createTBAAScalarTypeNode(os.str(), meta); + } + } + inst->setMetadata( + "tbaa", + md_builder_->createTBAAStructTagNode(meta, meta, 0)); } -void CodeGenLLVM::Visit_(const UIntImm* op) { - value_ = llvm::ConstantInt::get(LLVMType(op->type), op->value); +llvm::Value* CodeGenLLVM::CreateBroadcast(llvm::Value* value, int lanes) { + llvm::Constant* init = llvm::UndefValue::get( + llvm::VectorType::get(value->getType(), lanes)); + llvm::Constant* zero = ConstInt32(0); + value = builder_->CreateInsertElement(init, value, zero); + llvm::Constant* mask = llvm::ConstantVector::getSplat(lanes, zero); + return builder_->CreateShuffleVector(value, init, mask); } -void CodeGenLLVM::Visit_(const FloatImm* op) { - value_ = llvm::ConstantFP::get(LLVMType(op->type), op->value); -} +llvm::Value* CodeGenLLVM::CreateBufferPtr( + Type t, llvm::Value* buffer, llvm::Value* index) { + llvm::Type* elem_type = buffer->getType(); + unsigned address_space = elem_type->getPointerAddressSpace(); + llvm::Type* load_type = LLVMType(t)->getPointerTo(address_space); -void CodeGenLLVM::Visit_(const StringImm* op) { - value_ = GetConstString(op->value); + if (load_type != elem_type) { + buffer = builder_->CreatePointerCast(buffer, load_type); + } + llvm::Constant* cindex = llvm::dyn_cast(index); + if (cindex && cindex->isZeroValue()) { + return buffer; + } + return builder_->CreateInBoundsGEP(buffer, index); } -#define DEFINE_CODEGEN_BINARY_OP(OP) \ - llvm::Value* CodeGenLLVM::Create ## OP( \ - Type t, llvm::Value* a, llvm::Value *b) { \ - if (t.is_float()) { \ - return builder_->CreateF ## OP (a, b); \ - } else if (t.is_int() && t.bits() >= 32) { \ - return builder_->CreateNSW ## OP (a, b); \ - } else { \ - return builder_->Create ## OP (a, b); \ - } \ - } \ - -DEFINE_CODEGEN_BINARY_OP(Add); -DEFINE_CODEGEN_BINARY_OP(Sub); -DEFINE_CODEGEN_BINARY_OP(Mul); - -void CodeGenLLVM::Visit_(const Add* op) { - value_ = CreateAdd(op->type, MakeValue(op->a), MakeValue(op->b)); +llvm::Value* CodeGenLLVM::CreateCast(Type from, Type to, llvm::Value* value) { + llvm::Type * target = LLVMType(to); + if (value->getType() == target) return value; + if (from.is_handle() && from.is_handle()) { + return builder_->CreateBitCast(value, target); + } else if (!from.is_float() && !to.is_float()) { + return builder_->CreateIntCast(value, target, from.is_int()); + } else if (from.is_float() && to.is_int()) { + return builder_->CreateFPToSI(value, target); + } else if (from.is_float() && to.is_uint()) { + if (to.bits() < 8) { + value = builder_->CreateFPToUI(value, LLVMType(to.with_bits(8))); + return builder_->CreateIntCast(value, target, false); + } else { + return builder_->CreateFPToUI(value, target); + } + } else if (from.is_int() && to.is_float()) { + return builder_->CreateSIToFP(value, target); + } else if (from.is_uint() && to.is_float()) { + return builder_->CreateUIToFP(value, target); + } else { + CHECK(from.is_float() && to.is_float()); + return builder_->CreateFPCast(value, target); + } } -void CodeGenLLVM::Visit_(const Sub* op) { - value_ = CreateSub(op->type, MakeValue(op->a), MakeValue(op->b)); +llvm::Value* CodeGenLLVM::GetPackedFuncHandle(const std::string& fname) { + using llvm::BasicBlock; + // We will store the packed function handle in global space. + // Initialize it during the first call. + llvm::DataLayout layout(module_.get()); + uint64_t align = layout.getTypeAllocSize(t_tvm_func_handle_); + auto it = func_handle_map_.find(fname); + + llvm::GlobalVariable* hptr; + if (it == func_handle_map_.end()) { + // create global location for the handle + // create the function handle + hptr = new llvm::GlobalVariable( + *module_, t_tvm_func_handle_, false, + llvm::GlobalValue::PrivateLinkage, 0, ".tvm_func"); + hptr->setAlignment(align); + hptr->setInitializer(llvm::Constant::getNullValue(t_tvm_func_handle_)); + func_handle_map_[fname] = hptr; + } else { + hptr = it->second; + } + // create emit codes that checks and load the function. + BasicBlock* pre_block = builder_->GetInsertBlock(); + BasicBlock* init_block = BasicBlock::Create( + *ctx_, "handle_init", function_); + BasicBlock* end_block = BasicBlock::Create( + *ctx_, "handle_init_end", function_); + llvm::Value* handle = builder_->CreateAlignedLoad(hptr, align); + llvm::Value* handle_not_null = builder_->CreateICmpNE( + handle, llvm::Constant::getNullValue(t_tvm_func_handle_)); + builder_->CreateCondBr( + handle_not_null, end_block, init_block, md_very_likely_branch_); + // Initialize the handle if needed. + builder_->SetInsertPoint(init_block); + llvm::Value* out = builder_->CreateAlloca(t_tvm_func_handle_); + llvm::Value* ctx = builder_->CreateLoad(gv_mod_ctx_); + llvm::Value* retcode = builder_->CreateCall( + f_tvm_get_func_from_env_, {ctx, GetConstString(fname), out}); + init_block = CheckCallSuccess(retcode); + llvm::Value* loaded_handle = builder_->CreateAlignedLoad(out, align); + builder_->CreateBr(end_block); + // end block + builder_->SetInsertPoint(end_block); + llvm::PHINode* phi = builder_->CreatePHI(t_tvm_func_handle_, 2); + phi->addIncoming(handle, pre_block); + phi->addIncoming(loaded_handle, init_block); + return phi; } -void CodeGenLLVM::Visit_(const Mul* op) { - value_ = CreateMul(op->type, MakeValue(op->a), MakeValue(op->b)); +llvm::Value* CodeGenLLVM::CreateCallPacked(const Call* op) { + CHECK_GE(op->args.size(), 1U); + std::string func_name = op->args[0].as()->value; + llvm::Value* handle = GetPackedFuncHandle(func_name); + // call the function + unsigned nargs = static_cast(op->args.size() - 1); + llvm::Value* targs = builder_->CreateAlloca( + t_tvm_value_, ConstInt32(nargs)); + llvm::Value* tcodes = builder_->CreateAlloca( + t_int_, ConstInt32(nargs)); + for (unsigned i = 0; i < nargs; ++i) { + Expr expr = op->args[i + 1]; + Type t = expr.type(); + CHECK_EQ(t.lanes(), 1); + // Always pass via 64 bit value. + // For handle type, Handle(64) maps to 32 bit void* in 32bit platform. + Type api_type = t.with_bits(64); + llvm::Value* value = CreateCast(t, api_type, MakeValue(expr)); + llvm::Value* store_ptr = builder_->CreatePointerCast( + builder_->CreateInBoundsGEP(targs, ConstInt32(i)), + LLVMType(api_type)->getPointerTo()); + builder_->CreateAlignedStore(value, store_ptr, 8); + builder_->CreateAlignedStore( + ConstInt32(t.code()), + builder_->CreateInBoundsGEP(tcodes, ConstInt32(i)), 4); + } + llvm::Value* ret_value = builder_->CreateAlloca(t_tvm_value_); + llvm::Value* ret_tcode = builder_->CreateAlloca(t_int_); + CheckCallSuccess( + builder_->CreateCall( + f_tvm_func_call_, + {handle, targs, tcodes, ConstInt32(nargs), ret_value, ret_tcode})); + Type r_type = op->type; + Type r_api_type = op->type.with_bits(64); + llvm::Value* rvalue = + builder_->CreateAlignedLoad( + builder_->CreatePointerCast( + ret_value, LLVMType(r_api_type)->getPointerTo()), 8); + rvalue = CreateCast(r_api_type, r_type, rvalue); + return rvalue; } -void CodeGenLLVM::Visit_(const Div* op) { - llvm::Value* a = MakeValue(op->a); - int shift; - if (op->type.is_float()) { - value_ = builder_->CreateFDiv(a, MakeValue(op->b)); - } else if ((op->type.is_int() || op->type.is_uint()) && - is_const_power_of_two_integer(op->b, &shift)) { - value_ = builder_->CreateAShr(a, shift); +llvm::Value* CodeGenLLVM::CreateCallExtern(const Call* op) { + std::vector arg_values(op->args.size()); + for (size_t i = 0; i < op->args.size(); ++i) { + arg_values[i] = MakeValue(op->args[i]); + } + if (op->type.is_scalar()) { + llvm::Function* f = module_->getFunction(op->name); + if (f) { + return builder_->CreateCall(f, arg_values); + } else { + LOG(FATAL) << "cannot find function " << op->name; + } } else { - llvm::Value* b = MakeValue(op->b); - if (op->type.is_int()) { - value_ = builder_->CreateSDiv(a, b); + llvm::Function* f = module_->getFunction(op->name); + if (f) { + return CreateScalarizedCall(op, f, arg_values); } else { - CHECK(op->type.is_uint()); - value_ = builder_->CreateUDiv(a, b); + LOG(FATAL) << "cannot find function " << op->name; } } + return nullptr; } -void CodeGenLLVM::Visit_(const Mod* op) { - CHECK(!op->type.is_float()) - << "Cannot do mod for float"; - if (op->type.is_int()) { - value_ = builder_->CreateSRem(MakeValue(op->a), MakeValue(op->b)); - } else { - CHECK(op->type.is_uint()); - value_ = builder_->CreateURem(MakeValue(op->a), MakeValue(op->b)); +llvm::Value* CodeGenLLVM::CreateScalarizedCall( + const Call* op, llvm::Function* f, const std::vector& args) { + llvm::Value* value = llvm::UndefValue::get(LLVMType(op->type)); + for (int i = 0; i < op->type.lanes(); ++i) { + std::vector sargs(args.size()); + for (size_t j = 0; j < args.size(); ++j) { + if (args[j]->getType()->isVectorTy()) { + sargs[j] = builder_->CreateExtractElement(args[j], ConstInt32(i)); + } else { + sargs[j] = args[j]; + } + } + llvm::CallInst* call = builder_->CreateCall(f, sargs); + if (op->is_pure()) { + call->setDoesNotAccessMemory(); + } + call->setDoesNotThrow(); + if (!call->getType()->isVoidTy()) { + value = builder_->CreateInsertElement(value, call, ConstInt32(i)); + } } + return value; } -void CodeGenLLVM::Visit_(const Min* op) { - llvm::Value* a = MakeValue(op->a); - llvm::Value* b = MakeValue(op->b); - llvm::Value* cond = CreateLT(op->a.type(), a, b); - value_ = builder_->CreateSelect(cond, a, b); -} - -void CodeGenLLVM::Visit_(const Max* op) { - llvm::Value* a = MakeValue(op->a); - llvm::Value* b = MakeValue(op->b); - llvm::Value* cond = CreateGT(op->a.type(), a, b); - value_ = builder_->CreateSelect(cond, a, b); -} - -#define DEFINE_CODEGEN_CMP_OP(OP) \ - llvm::Value* CodeGenLLVM::Create ## OP( \ - Type t, llvm::Value* a, llvm::Value* b) { \ - if (t.is_float()) { \ - return builder_->CreateFCmpO ## OP (a, b); \ - } else if (t.is_int()) { \ - return builder_->CreateICmpS ## OP (a, b); \ - } else { \ - return builder_->CreateICmpU ## OP (a, b); \ - } \ - } \ - -DEFINE_CODEGEN_CMP_OP(LT); -DEFINE_CODEGEN_CMP_OP(LE); -DEFINE_CODEGEN_CMP_OP(GT); -DEFINE_CODEGEN_CMP_OP(GE); - -void CodeGenLLVM::Visit_(const LT* op) { - value_ = CreateLT(op->a.type(), MakeValue(op->a), MakeValue(op->b)); -} -void CodeGenLLVM::Visit_(const LE* op) { - value_ = CreateLE(op->a.type(), MakeValue(op->a), MakeValue(op->b)); -} -void CodeGenLLVM::Visit_(const GT* op) { - value_ = CreateGT(op->a.type(), MakeValue(op->a), MakeValue(op->b)); -} -void CodeGenLLVM::Visit_(const GE* op) { - value_ = CreateGE(op->a.type(), MakeValue(op->a), MakeValue(op->b)); -} - -void CodeGenLLVM::Visit_(const EQ* op) { - if (op->a.type().is_float()) { - value_ = builder_->CreateFCmpOEQ(MakeValue(op->a), MakeValue(op->b)); - } else { - value_ = builder_->CreateICmpEQ(MakeValue(op->a), MakeValue(op->b)); - } +llvm::Value* CodeGenLLVM::GetVarValue(const Variable* v) const { + auto it = var_map_.find(v); + CHECK(it != var_map_.end()) + << "Cannot find " << v->name_hint << " in the var map"; + return it->second; } -void CodeGenLLVM::Visit_(const NE* op) { - if (op->a.type().is_float()) { - value_ = builder_->CreateFCmpONE(MakeValue(op->a), MakeValue(op->b)); +llvm::Value* CodeGenLLVM::GetConstString(const std::string& str) { + auto it = str_map_.find(str); + if (it == str_map_.end()) { + llvm::Type* type = llvm::ArrayType::get(t_char_, str.length() + 1); + llvm::GlobalVariable *global = new llvm::GlobalVariable( + *module_, type, true, llvm::GlobalValue::PrivateLinkage, 0, ".str"); + global->setAlignment(1); + global->setInitializer(llvm::ConstantDataArray::getString(*ctx_, str)); + // useful constant value + llvm::Constant* zero = ConstInt32(0); + llvm::Constant* indices[] = {zero, zero}; + llvm::Constant* sptr = llvm::ConstantExpr::getGetElementPtr( + type, global, indices); + str_map_[str] = sptr; + return sptr; } else { - value_ = builder_->CreateICmpNE(MakeValue(op->a), MakeValue(op->b)); + return it->second; } } -void CodeGenLLVM::Visit_(const And* op) { - value_ = builder_->CreateAnd(MakeValue(op->a), MakeValue(op->b)); -} - -void CodeGenLLVM::Visit_(const Or* op) { - value_ = builder_->CreateOr(MakeValue(op->a), MakeValue(op->b)); -} - -void CodeGenLLVM::Visit_(const Not* op) { - value_ = builder_->CreateNot(MakeValue(op->a)); -} - -void CodeGenLLVM::Visit_(const Select* op) { - value_ = builder_->CreateSelect( - MakeValue(op->condition), - MakeValue(op->true_value), - MakeValue(op->false_value)); -} - -void CodeGenLLVM::Visit_(const Let* op) { - llvm::Value* v = MakeValue(op->value); - CHECK(!var_map_.count(op->var.get())); - var_map_[op->var.get()] = v; - value_ = MakeValue(op->body); -} - -void CodeGenLLVM::Visit_(const Broadcast* op) { - value_ = CreateBroadcast(MakeValue(op->value), op->lanes); -} - -void CodeGenLLVM::Visit_(const Ramp* op) { - Type t = op->type; - llvm::Value* base = MakeValue(op->base); - llvm::Value* stride = MakeValue(op->stride); - llvm::Value* value = llvm::UndefValue::get(LLVMType(t)); - for (int i = 0; i < t.lanes(); ++i) { - if (i != 0) { - base = CreateAdd(t, base, stride); - } - value = builder_->CreateInsertElement( - value, base, llvm::ConstantInt::get(t_int32_, i)); +void CodeGenLLVM::CreateParallelFor(const For* op) { + using llvm::BasicBlock; + llvm::Value* min = MakeValue(op->min); + llvm::Value* extent = MakeValue(op->extent); + min = builder_->CreateIntCast(min, t_int64_, op->min.type().is_int()); + extent = builder_->CreateIntCast(extent, t_int64_, op->min.type().is_int()); + // fields to be packed into closure. + Var loop_var(op->loop_var.node_); + Array vfields = ir::UndefinedVars(op->body, {loop_var}); + std::vector fields; + for (Var v : vfields) { + auto it = var_map_.find(v.get()); + CHECK(it != var_map_.end()); + fields.push_back(it->second->getType()); } - value_ = value; -} - -void CodeGenLLVM::Visit_(const Load* op) { - Type t = op->type; - CHECK(!t.is_vector()); + // closure data + llvm::StructType* tcdata = llvm::StructType::create(fields); + llvm::Function* f = llvm::Function::Create( + t_f_tvm_par_for_lambda_, + llvm::Function::PrivateLinkage, + "__tvm_par_for_lambda", module_.get()); + // allocate and setup the closure, call the closure. + llvm::Value* cdata = builder_->CreateAlloca(tcdata, ConstInt32(1)); + llvm::Value* zero = ConstInt32(0); - if (t.is_scalar()) { - llvm::LoadInst* inst = builder_->CreateAlignedLoad( - CreateBufferPtr( - t, - GetVarValue(op->buffer_var.get()), - MakeValue(op->index)), - data_layout_->getTypeAllocSize(LLVMType(t))); - AddAliasInfo(inst, op->buffer_var.get(), op->index); - value_ = inst; - } else { - LOG(FATAL) << "not yet supported"; + for (size_t i = 0; i < vfields.size(); ++i) { + builder_->CreateStore( + var_map_.at(vfields[i].get()), + builder_->CreateInBoundsGEP(cdata, {zero, ConstInt32(i)})); } -} - -void CodeGenLLVM::Visit_(const Store* op) { - llvm::Value* value = MakeValue(op->value); - Type t = op->value.type(); - CHECK(!t.is_vector()); - if (t.is_scalar()) { - llvm::StoreInst* inst = builder_->CreateAlignedStore( - value, - CreateBufferPtr( - t, - GetVarValue(op->buffer_var.get()), - MakeValue(op->index)), - data_layout_->getTypeAllocSize(value->getType())); - AddAliasInfo(inst, op->buffer_var.get(), op->index); - } else { - LOG(FATAL) << "not yet supported"; + BasicBlock* par_for_end = CheckCallSuccess( + builder_->CreateCall( + f_tvm_parallel_for_, + {min, extent, f, builder_->CreatePointerCast(cdata, t_void_p_)})); + // Setup the closure function. + BasicBlock *lambda_entry = BasicBlock::Create(*ctx_, "entry", f); + builder_->SetInsertPoint(lambda_entry); + auto it = f->arg_begin(); + llvm::Value* begin = &(*it++); + llvm::Value* end = &(*it++); + cdata = &(*it++); + begin = CreateCast(Int(64), op->loop_var.type(), begin); + end = CreateCast(Int(64), op->loop_var.type(), end); + cdata = builder_->CreatePointerCast(cdata, tcdata->getPointerTo()); + // setup new variable map, swap it with current var context. + std::unordered_map new_vmap; + for (size_t i = 0; i < vfields.size(); ++i) { + new_vmap[vfields[i].get()] = + builder_->CreateLoad(builder_->CreateInBoundsGEP( + cdata, {zero, ConstInt32(i)})); } + std::swap(function_, f); + std::swap(new_vmap, var_map_); + CreateSerialFor(begin, end, op->loop_var, op->body); + builder_->CreateRet(ConstInt32(0)); + // swap the var map back, now we are back on track. + std::swap(new_vmap, var_map_); + std::swap(function_, f); + builder_->SetInsertPoint(par_for_end); } -void CodeGenLLVM::Visit_(const Call* op) { - if (op->is_intrinsic(intrinsic::tvm_call_packed)) { - value_ = CreateCallPacked(op); - } else if (op->call_type == Call::Intrinsic || - op->call_type == Call::PureIntrinsic) { - value_ = CreateIntrinstic(op); - } else { - CHECK(op->call_type == Call::Extern || - op->call_type == Call::PureExtern); - value_ = CreateCallExtern(op); - } +void CodeGenLLVM::CreateSerialFor(llvm::Value* begin, llvm::Value* end, + const VarExpr& loop_var, const Stmt& body) { + using llvm::BasicBlock; + Type t = loop_var.type(); + BasicBlock* for_head = BasicBlock::Create( + *ctx_, "for_head", function_); + BasicBlock* for_body = BasicBlock::Create( + *ctx_, "for_body", function_); + BasicBlock* for_end = BasicBlock::Create( + *ctx_, "for_end", function_); + BasicBlock* pre_block = builder_->GetInsertBlock(); + builder_->CreateBr(for_head); + builder_->SetInsertPoint(for_head); + llvm::PHINode* index = builder_->CreatePHI(begin->getType(), 2); + index->addIncoming(begin, pre_block); + llvm::Value* cond = CreateLT(t, index, end); + builder_->CreateCondBr(cond, for_body, for_end, md_very_likely_branch_); + // body of for + builder_->SetInsertPoint(for_body); + var_map_[loop_var.get()] = index; + this->VisitStmt(body); + llvm::Value* next_index = CreateAdd(t, index, ConstInt32(1)); + index->addIncoming(next_index, builder_->GetInsertBlock()); + builder_->CreateBr(for_head); + // end of for + builder_->SetInsertPoint(for_end); } llvm::Value* CodeGenLLVM::CreateIntrinstic(const Call* op) { @@ -555,70 +684,292 @@ llvm::Value* CodeGenLLVM::CreateIntrinstic(const Call* op) { return nullptr; } -llvm::BasicBlock* CodeGenLLVM::CheckCallSuccess(llvm::Value* retcode) { - // create emit codes that checks and load the function. - using llvm::BasicBlock; - BasicBlock* fail_block = BasicBlock::Create( - *ctx_, "call_fail", function_); - BasicBlock* end_block = BasicBlock::Create( - *ctx_, "call_end", function_); - llvm::Value* succ = builder_->CreateICmpEQ( - retcode, llvm::ConstantInt::get(t_int_, 0)); - builder_->CreateCondBr(succ, end_block, fail_block, md_very_likely_branch_); - builder_->SetInsertPoint(fail_block); - // return the code. - builder_->CreateRet(retcode); - // otherwise set it to be new end. - builder_->SetInsertPoint(end_block); - return end_block; +// visitor overrides +llvm::Value* CodeGenLLVM::VisitExpr_(const Variable* op) { + return GetVarValue(op); } -void CodeGenLLVM::Visit_(const For* op) { - CHECK(is_zero(op->min)); - if (op->for_type == ForType::Serial) { - CreateSerialFor(ConstInt32(0), MakeValue(op->extent), - op->loop_var, op->body); - } else if (op->for_type == ForType::Parallel) { - CreateParallelFor(op); - } else { - LOG(FATAL) << "cannot handle for type " << op->for_type; - } + +llvm::Value* CodeGenLLVM::VisitExpr_(const Cast* op) { + return CreateCast(op->value.type(), op->type, MakeValue(op->value)); } -void CodeGenLLVM::Visit_(const IfThenElse* op) { - using llvm::BasicBlock; - BasicBlock* then_block = BasicBlock::Create( - *ctx_, "if_then", function_); - BasicBlock* else_block = BasicBlock::Create( - *ctx_, "if_else", function_); - BasicBlock* end_block = BasicBlock::Create( - *ctx_, "if_end", function_); - if (!op->else_case.defined()) { - else_block = end_block; - } - // condition. - llvm::Value* cond = MakeValue(op->condition); - bool likely = true; - if (likely) { - builder_->CreateCondBr(cond, then_block, else_block, md_very_likely_branch_); - } else { - builder_->CreateCondBr(cond, then_block, else_block); - } - // then case. - builder_->SetInsertPoint(then_block); - this->Visit(op->then_case); - builder_->CreateBr(end_block); - // else case. - if (op->else_case.defined()) { - builder_->SetInsertPoint(else_block); - this->Visit(op->else_case); - builder_->CreateBr(end_block); - } - builder_->SetInsertPoint(end_block); +llvm::Value* CodeGenLLVM::VisitExpr_(const IntImm* op) { + return llvm::ConstantInt::getSigned(LLVMType(op->type), op->value); } -void CodeGenLLVM::Visit_(const Allocate* op) { - CHECK(!is_zero(op->condition)); - llvm::Value* buf = nullptr; +llvm::Value* CodeGenLLVM::VisitExpr_(const UIntImm* op) { + return llvm::ConstantInt::get(LLVMType(op->type), op->value); +} + +llvm::Value* CodeGenLLVM::VisitExpr_(const FloatImm* op) { + return llvm::ConstantFP::get(LLVMType(op->type), op->value); +} + +llvm::Value* CodeGenLLVM::VisitExpr_(const StringImm* op) { + return GetConstString(op->value); +} + +#define DEFINE_CODEGEN_BINARY_OP(OP) \ + llvm::Value* CodeGenLLVM::Create ## OP( \ + Type t, llvm::Value* a, llvm::Value *b) { \ + if (t.is_float()) { \ + return builder_->CreateF ## OP (a, b); \ + } else if (t.is_int() && t.bits() >= 32) { \ + return builder_->CreateNSW ## OP (a, b); \ + } else { \ + return builder_->Create ## OP (a, b); \ + } \ + } \ + +DEFINE_CODEGEN_BINARY_OP(Add); +DEFINE_CODEGEN_BINARY_OP(Sub); +DEFINE_CODEGEN_BINARY_OP(Mul); + +llvm::Value* CodeGenLLVM::VisitExpr_(const Add* op) { + return CreateAdd(op->type, MakeValue(op->a), MakeValue(op->b)); +} + +llvm::Value* CodeGenLLVM::VisitExpr_(const Sub* op) { + return CreateSub(op->type, MakeValue(op->a), MakeValue(op->b)); +} + +llvm::Value* CodeGenLLVM::VisitExpr_(const Mul* op) { + return CreateMul(op->type, MakeValue(op->a), MakeValue(op->b)); +} + +llvm::Value* CodeGenLLVM::VisitExpr_(const Div* op) { + llvm::Value* a = MakeValue(op->a); + int shift; + if (op->type.is_float()) { + return builder_->CreateFDiv(a, MakeValue(op->b)); + } else if ((op->type.is_int() || op->type.is_uint()) && + is_const_power_of_two_integer(op->b, &shift)) { + return builder_->CreateAShr(a, shift); + } else { + llvm::Value* b = MakeValue(op->b); + if (op->type.is_int()) { + return builder_->CreateSDiv(a, b); + } else { + CHECK(op->type.is_uint()); + return builder_->CreateUDiv(a, b); + } + } +} + +llvm::Value* CodeGenLLVM::VisitExpr_(const Mod* op) { + CHECK(!op->type.is_float()) + << "Cannot do mod for float"; + if (op->type.is_int()) { + return builder_->CreateSRem(MakeValue(op->a), MakeValue(op->b)); + } else { + CHECK(op->type.is_uint()); + return builder_->CreateURem(MakeValue(op->a), MakeValue(op->b)); + } +} + +llvm::Value* CodeGenLLVM::VisitExpr_(const Min* op) { + llvm::Value* a = MakeValue(op->a); + llvm::Value* b = MakeValue(op->b); + llvm::Value* cond = CreateLT(op->a.type(), a, b); + return builder_->CreateSelect(cond, a, b); +} + +llvm::Value* CodeGenLLVM::VisitExpr_(const Max* op) { + llvm::Value* a = MakeValue(op->a); + llvm::Value* b = MakeValue(op->b); + llvm::Value* cond = CreateGT(op->a.type(), a, b); + return builder_->CreateSelect(cond, a, b); +} + +#define DEFINE_CODEGEN_CMP_OP(OP) \ + llvm::Value* CodeGenLLVM::Create ## OP( \ + Type t, llvm::Value* a, llvm::Value* b) { \ + if (t.is_float()) { \ + return builder_->CreateFCmpO ## OP (a, b); \ + } else if (t.is_int()) { \ + return builder_->CreateICmpS ## OP (a, b); \ + } else { \ + return builder_->CreateICmpU ## OP (a, b); \ + } \ + } \ + +DEFINE_CODEGEN_CMP_OP(LT); +DEFINE_CODEGEN_CMP_OP(LE); +DEFINE_CODEGEN_CMP_OP(GT); +DEFINE_CODEGEN_CMP_OP(GE); + +llvm::Value* CodeGenLLVM::VisitExpr_(const LT* op) { + return CreateLT(op->a.type(), MakeValue(op->a), MakeValue(op->b)); +} +llvm::Value* CodeGenLLVM::VisitExpr_(const LE* op) { + return CreateLE(op->a.type(), MakeValue(op->a), MakeValue(op->b)); +} +llvm::Value* CodeGenLLVM::VisitExpr_(const GT* op) { + return CreateGT(op->a.type(), MakeValue(op->a), MakeValue(op->b)); +} +llvm::Value* CodeGenLLVM::VisitExpr_(const GE* op) { + return CreateGE(op->a.type(), MakeValue(op->a), MakeValue(op->b)); +} + +llvm::Value* CodeGenLLVM::VisitExpr_(const EQ* op) { + if (op->a.type().is_float()) { + return builder_->CreateFCmpOEQ(MakeValue(op->a), MakeValue(op->b)); + } else { + return builder_->CreateICmpEQ(MakeValue(op->a), MakeValue(op->b)); + } +} + +llvm::Value* CodeGenLLVM::VisitExpr_(const NE* op) { + if (op->a.type().is_float()) { + return builder_->CreateFCmpONE(MakeValue(op->a), MakeValue(op->b)); + } else { + return builder_->CreateICmpNE(MakeValue(op->a), MakeValue(op->b)); + } +} + +llvm::Value* CodeGenLLVM::VisitExpr_(const And* op) { + return builder_->CreateAnd(MakeValue(op->a), MakeValue(op->b)); +} + +llvm::Value* CodeGenLLVM::VisitExpr_(const Or* op) { + return builder_->CreateOr(MakeValue(op->a), MakeValue(op->b)); +} + +llvm::Value* CodeGenLLVM::VisitExpr_(const Not* op) { + return builder_->CreateNot(MakeValue(op->a)); +} + +llvm::Value* CodeGenLLVM::VisitExpr_(const Select* op) { + return builder_->CreateSelect( + MakeValue(op->condition), + MakeValue(op->true_value), + MakeValue(op->false_value)); +} + +llvm::Value* CodeGenLLVM::VisitExpr_(const Let* op) { + llvm::Value* v = MakeValue(op->value); + CHECK(!var_map_.count(op->var.get())); + var_map_[op->var.get()] = v; + return MakeValue(op->body); +} + +llvm::Value* CodeGenLLVM::VisitExpr_(const Broadcast* op) { + return CreateBroadcast(MakeValue(op->value), op->lanes); +} + +llvm::Value* CodeGenLLVM::VisitExpr_(const Ramp* op) { + Type t = op->type; + llvm::Value* base = MakeValue(op->base); + llvm::Value* stride = MakeValue(op->stride); + llvm::Value* value = llvm::UndefValue::get(LLVMType(t)); + for (int i = 0; i < t.lanes(); ++i) { + if (i != 0) { + base = CreateAdd(t, base, stride); + } + value = builder_->CreateInsertElement( + value, base, llvm::ConstantInt::get(t_int32_, i)); + } + return value; +} + +llvm::Value* CodeGenLLVM::VisitExpr_(const Load* op) { + Type t = op->type; + CHECK(!t.is_vector()); + + if (t.is_scalar()) { + llvm::LoadInst* inst = builder_->CreateAlignedLoad( + CreateBufferPtr( + t, + GetVarValue(op->buffer_var.get()), + MakeValue(op->index)), + data_layout_->getTypeAllocSize(LLVMType(t))); + AddAliasInfo(inst, op->buffer_var.get(), op->index); + return inst; + } else { + LOG(FATAL) << "not yet supported"; + return nullptr; + } +} + +llvm::Value* CodeGenLLVM::VisitExpr_(const Call* op) { + if (op->is_intrinsic(intrinsic::tvm_call_packed)) { + return CreateCallPacked(op); + } else if (op->call_type == Call::Intrinsic || + op->call_type == Call::PureIntrinsic) { + return CreateIntrinstic(op); + } else { + CHECK(op->call_type == Call::Extern || + op->call_type == Call::PureExtern); + return CreateCallExtern(op); + } +} + +// stmts +void CodeGenLLVM::VisitStmt_(const Store* op) { + llvm::Value* value = MakeValue(op->value); + Type t = op->value.type(); + CHECK(!t.is_vector()); + if (t.is_scalar()) { + llvm::StoreInst* inst = builder_->CreateAlignedStore( + value, + CreateBufferPtr( + t, + GetVarValue(op->buffer_var.get()), + MakeValue(op->index)), + data_layout_->getTypeAllocSize(value->getType())); + AddAliasInfo(inst, op->buffer_var.get(), op->index); + } else { + LOG(FATAL) << "not yet supported"; + } +} + +void CodeGenLLVM::VisitStmt_(const For* op) { + CHECK(is_zero(op->min)); + if (op->for_type == ForType::Serial) { + CreateSerialFor(ConstInt32(0), MakeValue(op->extent), + op->loop_var, op->body); + } else if (op->for_type == ForType::Parallel) { + CreateParallelFor(op); + } else { + LOG(FATAL) << "cannot handle for type " << op->for_type; + } +} + +void CodeGenLLVM::VisitStmt_(const IfThenElse* op) { + using llvm::BasicBlock; + BasicBlock* then_block = BasicBlock::Create( + *ctx_, "if_then", function_); + BasicBlock* else_block = BasicBlock::Create( + *ctx_, "if_else", function_); + BasicBlock* end_block = BasicBlock::Create( + *ctx_, "if_end", function_); + if (!op->else_case.defined()) { + else_block = end_block; + } + // condition. + llvm::Value* cond = MakeValue(op->condition); + bool likely = true; + if (likely) { + builder_->CreateCondBr(cond, then_block, else_block, md_very_likely_branch_); + } else { + builder_->CreateCondBr(cond, then_block, else_block); + } + // then case. + builder_->SetInsertPoint(then_block); + this->VisitStmt(op->then_case); + builder_->CreateBr(end_block); + // else case. + if (op->else_case.defined()) { + builder_->SetInsertPoint(else_block); + this->VisitStmt(op->else_case); + builder_->CreateBr(end_block); + } + builder_->SetInsertPoint(end_block); +} + +void CodeGenLLVM::VisitStmt_(const Allocate* op) { + CHECK(!is_zero(op->condition)); + llvm::Value* buf = nullptr; if (op->new_expr.defined()) { CHECK_EQ(op->free_function, "nop"); buf = MakeValue(op->new_expr); @@ -634,11 +985,11 @@ void CodeGenLLVM::Visit_(const Allocate* op) { var_map_[op->buffer_var.get()] = buf; } -void CodeGenLLVM::Visit_(const AttrStmt* op) { - this->Visit(op->body); +void CodeGenLLVM::VisitStmt_(const AttrStmt* op) { + this->VisitStmt(op->body); } -void CodeGenLLVM::Visit_(const AssertStmt* op) { +void CodeGenLLVM::VisitStmt_(const AssertStmt* op) { using llvm::BasicBlock; llvm::Value* cond = MakeValue(op->condition); std::ostringstream os; @@ -660,359 +1011,23 @@ void CodeGenLLVM::Visit_(const AssertStmt* op) { builder_->SetInsertPoint(end_block); } -void CodeGenLLVM::Visit_(const LetStmt* op) { +void CodeGenLLVM::VisitStmt_(const LetStmt* op) { llvm::Value* v = MakeValue(op->value); CHECK(!var_map_.count(op->var.get())); var_map_[op->var.get()] = v; - this->Visit(op->body); + this->VisitStmt(op->body); } - -void CodeGenLLVM::AddAliasInfo( - llvm::Instruction* inst, const Variable* buffer, Expr index) { - int base = 0, width = 0; - // create meta-data for alias analysis - // Use a group of binary tree ranges. - const Ramp* ramp = index.as(); - if (ramp) { - int base, stride; - if (arith::GetConstInt(ramp->base, &base) && - arith::GetConstInt(ramp->stride, &stride)) { - int xwith = ramp->lanes * stride; - width = 1; - while (width < xwith) { - width *= 2; - } - while (base % width) { - base -= base % width; - width *= 2; - } - } - } else { - if (arith::GetConstInt(index, &base)) width = 1; - } - - llvm::MDNode* meta = md_tbaa_root_; - std::ostringstream buffer_addr; - buffer_addr << buffer; - meta = md_builder_->createTBAAScalarTypeNode(buffer_addr.str(), meta); - // create a tree-shape access structure. - if (width != 0) { - for (int w = 1024; w >= width; w /= 2) { - int b = (base / w) * w; - std::stringstream os; - os << buffer << ".w" << w << ".b" << b; - meta = md_builder_->createTBAAScalarTypeNode(os.str(), meta); - } - } - inst->setMetadata( - "tbaa", - md_builder_->createTBAAStructTagNode(meta, meta, 0)); +void CodeGenLLVM::VisitStmt_(const Block* op) { + VisitStmt(op->first); + if (op->rest.defined()) VisitStmt(op->rest); +} +void CodeGenLLVM::VisitStmt_(const Evaluate *op) { + MakeValue(op->value); +} +void CodeGenLLVM::VisitStmt_(const ProducerConsumer* op) { + VisitStmt(op->body); } -llvm::Value* CodeGenLLVM::CreateBroadcast(llvm::Value* value, int lanes) { - llvm::Constant* init = llvm::UndefValue::get( - llvm::VectorType::get(value->getType(), lanes)); - llvm::Constant* zero = ConstInt32(0); - value = builder_->CreateInsertElement(init, value, zero); - llvm::Constant* mask = llvm::ConstantVector::getSplat(lanes, zero); - return builder_->CreateShuffleVector(value, init, mask); -} - -llvm::Value* CodeGenLLVM::CreateBufferPtr( - Type t, llvm::Value* buffer, llvm::Value* index) { - llvm::Type* elem_type = buffer->getType(); - unsigned address_space = elem_type->getPointerAddressSpace(); - llvm::Type* load_type = LLVMType(t)->getPointerTo(address_space); - - if (load_type != elem_type) { - buffer = builder_->CreatePointerCast(buffer, load_type); - } - llvm::Constant* cindex = llvm::dyn_cast(index); - if (cindex && cindex->isZeroValue()) { - return buffer; - } - return builder_->CreateInBoundsGEP(buffer, index); -} - -llvm::Value* CodeGenLLVM::CreateCast(Type from, Type to, llvm::Value* value) { - llvm::Type * target = LLVMType(to); - if (value->getType() == target) return value; - if (from.is_handle() && from.is_handle()) { - return builder_->CreateBitCast(value, target); - } else if (!from.is_float() && !to.is_float()) { - return builder_->CreateIntCast(value, target, from.is_int()); - } else if (from.is_float() && to.is_int()) { - return builder_->CreateFPToSI(value, target); - } else if (from.is_float() && to.is_uint()) { - if (to.bits() < 8) { - value = builder_->CreateFPToUI(value, LLVMType(to.with_bits(8))); - return builder_->CreateIntCast(value, target, false); - } else { - return builder_->CreateFPToUI(value, target); - } - } else if (from.is_int() && to.is_float()) { - return builder_->CreateSIToFP(value, target); - } else if (from.is_uint() && to.is_float()) { - return builder_->CreateUIToFP(value, target); - } else { - CHECK(from.is_float() && to.is_float()); - return builder_->CreateFPCast(value, target); - } -} - -llvm::Value* CodeGenLLVM::GetPackedFuncHandle(const std::string& fname) { - using llvm::BasicBlock; - // We will store the packed function handle in global space. - // Initialize it during the first call. - llvm::DataLayout layout(module_.get()); - uint64_t align = layout.getTypeAllocSize(t_tvm_func_handle_); - auto it = func_handle_map_.find(fname); - - llvm::GlobalVariable* hptr; - if (it == func_handle_map_.end()) { - // create global location for the handle - // create the function handle - hptr = new llvm::GlobalVariable( - *module_, t_tvm_func_handle_, false, - llvm::GlobalValue::PrivateLinkage, 0, ".tvm_func"); - hptr->setAlignment(align); - hptr->setInitializer(llvm::Constant::getNullValue(t_tvm_func_handle_)); - func_handle_map_[fname] = hptr; - } else { - hptr = it->second; - } - // create emit codes that checks and load the function. - BasicBlock* pre_block = builder_->GetInsertBlock(); - BasicBlock* init_block = BasicBlock::Create( - *ctx_, "handle_init", function_); - BasicBlock* end_block = BasicBlock::Create( - *ctx_, "handle_init_end", function_); - llvm::Value* handle = builder_->CreateAlignedLoad(hptr, align); - llvm::Value* handle_not_null = builder_->CreateICmpNE( - handle, llvm::Constant::getNullValue(t_tvm_func_handle_)); - builder_->CreateCondBr( - handle_not_null, end_block, init_block, md_very_likely_branch_); - // Initialize the handle if needed. - builder_->SetInsertPoint(init_block); - llvm::Value* out = builder_->CreateAlloca(t_tvm_func_handle_); - llvm::Value* ctx = builder_->CreateLoad(gv_mod_ctx_); - llvm::Value* retcode = builder_->CreateCall( - f_tvm_get_func_from_env_, {ctx, GetConstString(fname), out}); - init_block = CheckCallSuccess(retcode); - llvm::Value* loaded_handle = builder_->CreateAlignedLoad(out, align); - builder_->CreateBr(end_block); - // end block - builder_->SetInsertPoint(end_block); - llvm::PHINode* phi = builder_->CreatePHI(t_tvm_func_handle_, 2); - phi->addIncoming(handle, pre_block); - phi->addIncoming(loaded_handle, init_block); - return phi; -} - -llvm::Value* CodeGenLLVM::CreateCallPacked(const Call* op) { - CHECK_GE(op->args.size(), 1U); - std::string func_name = op->args[0].as()->value; - llvm::Value* handle = GetPackedFuncHandle(func_name); - // call the function - unsigned nargs = static_cast(op->args.size() - 1); - llvm::Value* targs = builder_->CreateAlloca( - t_tvm_value_, ConstInt32(nargs)); - llvm::Value* tcodes = builder_->CreateAlloca( - t_int_, ConstInt32(nargs)); - for (unsigned i = 0; i < nargs; ++i) { - Expr expr = op->args[i + 1]; - Type t = expr.type(); - CHECK_EQ(t.lanes(), 1); - // Always pass via 64 bit value. - // For handle type, Handle(64) maps to 32 bit void* in 32bit platform. - Type api_type = t.with_bits(64); - llvm::Value* value = CreateCast(t, api_type, MakeValue(expr)); - llvm::Value* store_ptr = builder_->CreatePointerCast( - builder_->CreateInBoundsGEP(targs, ConstInt32(i)), - LLVMType(api_type)->getPointerTo()); - builder_->CreateAlignedStore(value, store_ptr, 8); - builder_->CreateAlignedStore( - ConstInt32(t.code()), - builder_->CreateInBoundsGEP(tcodes, ConstInt32(i)), 4); - } - llvm::Value* ret_value = builder_->CreateAlloca(t_tvm_value_); - llvm::Value* ret_tcode = builder_->CreateAlloca(t_int_); - CheckCallSuccess( - builder_->CreateCall( - f_tvm_func_call_, - {handle, targs, tcodes, ConstInt32(nargs), ret_value, ret_tcode})); - Type r_type = op->type; - Type r_api_type = op->type.with_bits(64); - llvm::Value* rvalue = - builder_->CreateAlignedLoad( - builder_->CreatePointerCast( - ret_value, LLVMType(r_api_type)->getPointerTo()), 8); - rvalue = CreateCast(r_api_type, r_type, rvalue); - return rvalue; -} - -llvm::Value* CodeGenLLVM::CreateCallExtern(const Call* op) { - std::vector arg_values(op->args.size()); - for (size_t i = 0; i < op->args.size(); ++i) { - arg_values[i] = MakeValue(op->args[i]); - } - if (op->type.is_scalar()) { - llvm::Function* f = module_->getFunction(op->name); - if (f) { - return builder_->CreateCall(f, arg_values); - } else { - LOG(FATAL) << "cannot find function " << op->name; - } - } else { - llvm::Function* f = module_->getFunction(op->name); - if (f) { - return CreateScalarizedCall(op, f, arg_values); - } else { - LOG(FATAL) << "cannot find function " << op->name; - } - } - return nullptr; -} - -llvm::Value* CodeGenLLVM::CreateScalarizedCall( - const Call* op, llvm::Function* f, const std::vector& args) { - llvm::Value* value = llvm::UndefValue::get(LLVMType(op->type)); - for (int i = 0; i < op->type.lanes(); ++i) { - std::vector sargs(args.size()); - for (size_t j = 0; j < args.size(); ++j) { - if (args[j]->getType()->isVectorTy()) { - sargs[j] = builder_->CreateExtractElement(args[j], ConstInt32(i)); - } else { - sargs[j] = args[j]; - } - } - llvm::CallInst* call = builder_->CreateCall(f, sargs); - if (op->is_pure()) { - call->setDoesNotAccessMemory(); - } - call->setDoesNotThrow(); - if (!call->getType()->isVoidTy()) { - value = builder_->CreateInsertElement(value, call, ConstInt32(i)); - } - } - return value; -} - -llvm::Value* CodeGenLLVM::GetVarValue(const Variable* v) const { - auto it = var_map_.find(v); - CHECK(it != var_map_.end()) - << "Cannot find " << v->name_hint << " in the var map"; - return it->second; -} - -llvm::Value* CodeGenLLVM::GetConstString(const std::string& str) { - auto it = str_map_.find(str); - if (it == str_map_.end()) { - llvm::Type* type = llvm::ArrayType::get(t_char_, str.length() + 1); - llvm::GlobalVariable *global = new llvm::GlobalVariable( - *module_, type, true, llvm::GlobalValue::PrivateLinkage, 0, ".str"); - global->setAlignment(1); - global->setInitializer(llvm::ConstantDataArray::getString(*ctx_, str)); - // useful constant value - llvm::Constant* zero = ConstInt32(0); - llvm::Constant* indices[] = {zero, zero}; - llvm::Constant* sptr = llvm::ConstantExpr::getGetElementPtr( - type, global, indices); - str_map_[str] = sptr; - return sptr; - } else { - return it->second; - } -} - -void CodeGenLLVM::CreateParallelFor(const For* op) { - using llvm::BasicBlock; - llvm::Value* min = MakeValue(op->min); - llvm::Value* extent = MakeValue(op->extent); - min = builder_->CreateIntCast(min, t_int64_, op->min.type().is_int()); - extent = builder_->CreateIntCast(extent, t_int64_, op->min.type().is_int()); - // fields to be packed into closure. - Var loop_var(op->loop_var.node_); - Array vfields = ir::UndefinedVars(op->body, {loop_var}); - std::vector fields; - for (Var v : vfields) { - auto it = var_map_.find(v.get()); - CHECK(it != var_map_.end()); - fields.push_back(it->second->getType()); - } - // closure data - llvm::StructType* tcdata = llvm::StructType::create(fields); - llvm::Function* f = llvm::Function::Create( - t_f_tvm_par_for_lambda_, - llvm::Function::PrivateLinkage, - "__tvm_par_for_lambda", module_.get()); - // allocate and setup the closure, call the closure. - llvm::Value* cdata = builder_->CreateAlloca(tcdata, ConstInt32(1)); - llvm::Value* zero = ConstInt32(0); - - for (size_t i = 0; i < vfields.size(); ++i) { - builder_->CreateStore( - var_map_.at(vfields[i].get()), - builder_->CreateInBoundsGEP(cdata, {zero, ConstInt32(i)})); - } - BasicBlock* par_for_end = CheckCallSuccess( - builder_->CreateCall( - f_tvm_parallel_for_, - {min, extent, f, builder_->CreatePointerCast(cdata, t_void_p_)})); - // Setup the closure function. - BasicBlock *lambda_entry = BasicBlock::Create(*ctx_, "entry", f); - builder_->SetInsertPoint(lambda_entry); - auto it = f->arg_begin(); - llvm::Value* begin = &(*it++); - llvm::Value* end = &(*it++); - cdata = &(*it++); - begin = CreateCast(Int(64), op->loop_var.type(), begin); - end = CreateCast(Int(64), op->loop_var.type(), end); - cdata = builder_->CreatePointerCast(cdata, tcdata->getPointerTo()); - // setup new variable map, swap it with current var context. - std::unordered_map new_vmap; - for (size_t i = 0; i < vfields.size(); ++i) { - new_vmap[vfields[i].get()] = - builder_->CreateLoad(builder_->CreateInBoundsGEP( - cdata, {zero, ConstInt32(i)})); - } - std::swap(function_, f); - std::swap(new_vmap, var_map_); - CreateSerialFor(begin, end, op->loop_var, op->body); - builder_->CreateRet(ConstInt32(0)); - // swap the var map back, now we are back on track. - std::swap(new_vmap, var_map_); - std::swap(function_, f); - builder_->SetInsertPoint(par_for_end); -} - -void CodeGenLLVM::CreateSerialFor(llvm::Value* begin, llvm::Value* end, - const VarExpr& loop_var, const Stmt& body) { - using llvm::BasicBlock; - Type t = loop_var.type(); - BasicBlock* for_head = BasicBlock::Create( - *ctx_, "for_head", function_); - BasicBlock* for_body = BasicBlock::Create( - *ctx_, "for_body", function_); - BasicBlock* for_end = BasicBlock::Create( - *ctx_, "for_end", function_); - BasicBlock* pre_block = builder_->GetInsertBlock(); - builder_->CreateBr(for_head); - builder_->SetInsertPoint(for_head); - llvm::PHINode* index = builder_->CreatePHI(begin->getType(), 2); - index->addIncoming(begin, pre_block); - llvm::Value* cond = CreateLT(t, index, end); - builder_->CreateCondBr(cond, for_body, for_end, md_very_likely_branch_); - // body of for - builder_->SetInsertPoint(for_body); - var_map_[loop_var.get()] = index; - this->Visit(body); - llvm::Value* next_index = CreateAdd(t, index, ConstInt32(1)); - index->addIncoming(next_index, builder_->GetInsertBlock()); - builder_->CreateBr(for_head); - // end of for - builder_->SetInsertPoint(for_end); -} } // namespace codegen } // namespace tvm #endif // TVM_LLVM_VERSION diff --git a/src/codegen/llvm/codegen_llvm.h b/src/codegen/llvm/codegen_llvm.h index 3f7c197c270d..aed75a866d16 100644 --- a/src/codegen/llvm/codegen_llvm.h +++ b/src/codegen/llvm/codegen_llvm.h @@ -8,7 +8,7 @@ #ifdef TVM_LLVM_VERSION #include -#include +#include #include #include #include @@ -23,7 +23,9 @@ using namespace ir; /*! * \brief A base class to generate a LLVM. */ -class CodeGenLLVM : public IRVisitor { +class CodeGenLLVM : + public ExprFunctor, + public StmtFunctor { public: /*! * \brief Initialize the code generator with given context @@ -55,52 +57,52 @@ class CodeGenLLVM : public IRVisitor { * \return created value. */ llvm::Value* MakeValue(const Expr& e) { - value_ = nullptr; - this->Visit(e); - CHECK(value_ != nullptr); - return value_; + return VisitExpr(e); } // Short hande code to get a constant int 32 llvm::Constant* ConstInt32(unsigned value) const { return llvm::ConstantInt::get(t_int32_, value); } // override codegen - void Visit_(const Variable* op) final; - void Visit_(const Cast* op) final; - void Visit_(const IntImm* op) final; - void Visit_(const UIntImm* op) final; - void Visit_(const FloatImm* op) final; - void Visit_(const StringImm* op) final; - void Visit_(const Add* op) final; - void Visit_(const Sub* op) final; - void Visit_(const Mul* op) final; - void Visit_(const Div* op) final; - void Visit_(const Mod* op) final; - void Visit_(const Min* op) final; - void Visit_(const Max* op) final; - void Visit_(const LT* op) final; - void Visit_(const LE* op) final; - void Visit_(const GT* op) final; - void Visit_(const GE* op) final; - void Visit_(const EQ* op) final; - void Visit_(const NE* op) final; - void Visit_(const And* op) final; - void Visit_(const Or* op) final; - void Visit_(const Not* op) final; - void Visit_(const Select* op) final; - void Visit_(const Let* op) final; - void Visit_(const Load* op) final; - void Visit_(const Call* op) final; - void Visit_(const Ramp* op) final; - void Visit_(const Broadcast* op) final; + llvm::Value* VisitExpr_(const Variable* op) override; + llvm::Value* VisitExpr_(const Cast* op) override; + llvm::Value* VisitExpr_(const IntImm* op) override; + llvm::Value* VisitExpr_(const UIntImm* op) override; + llvm::Value* VisitExpr_(const FloatImm* op) override; + llvm::Value* VisitExpr_(const StringImm* op) override; + llvm::Value* VisitExpr_(const Add* op) override; + llvm::Value* VisitExpr_(const Sub* op) override; + llvm::Value* VisitExpr_(const Mul* op) override; + llvm::Value* VisitExpr_(const Div* op) override; + llvm::Value* VisitExpr_(const Mod* op) override; + llvm::Value* VisitExpr_(const Min* op) override; + llvm::Value* VisitExpr_(const Max* op) override; + llvm::Value* VisitExpr_(const LT* op) override; + llvm::Value* VisitExpr_(const LE* op) override; + llvm::Value* VisitExpr_(const GT* op) override; + llvm::Value* VisitExpr_(const GE* op) override; + llvm::Value* VisitExpr_(const EQ* op) override; + llvm::Value* VisitExpr_(const NE* op) override; + llvm::Value* VisitExpr_(const And* op) override; + llvm::Value* VisitExpr_(const Or* op) override; + llvm::Value* VisitExpr_(const Not* op) override; + llvm::Value* VisitExpr_(const Select* op) override; + llvm::Value* VisitExpr_(const Let* op) override; + llvm::Value* VisitExpr_(const Load* op) override; + llvm::Value* VisitExpr_(const Call* op) override; + llvm::Value* VisitExpr_(const Ramp* op) override; + llvm::Value* VisitExpr_(const Broadcast* op) override; // stmt - void Visit_(const Store* op) final; - void Visit_(const For* op) final; - void Visit_(const IfThenElse* op) final; - void Visit_(const Allocate* op) final; - void Visit_(const AttrStmt* op) override; - void Visit_(const AssertStmt* op) final; - void Visit_(const LetStmt* op) final; + void VisitStmt_(const Store* op) override; + void VisitStmt_(const For* op) override; + void VisitStmt_(const IfThenElse* op) override; + void VisitStmt_(const Allocate* op) override; + void VisitStmt_(const AttrStmt* op) override; + void VisitStmt_(const AssertStmt* op) override; + void VisitStmt_(const LetStmt* op) override; + void VisitStmt_(const Block* op) override; + void VisitStmt_(const Evaluate* op) override; + void VisitStmt_(const ProducerConsumer* op) override; // create intrinstic given call virtual llvm::Value* CreateIntrinstic(const Call* op); // create extern function call @@ -160,8 +162,6 @@ class CodeGenLLVM : public IRVisitor { llvm::Function* f_tvm_parallel_for_{nullptr}; // The acting body llvm::BasicBlock* block_{nullptr}; - // Last value returned codegen call. - llvm::Value* value_{nullptr}; private: // comparison op From c21807be08c330865e1dc2d3d2e13904976e5ae3 Mon Sep 17 00:00:00 2001 From: tqchen Date: Tue, 28 Feb 2017 15:20:31 -0800 Subject: [PATCH 5/7] [IRFunctor] Migrate canonical --- src/arithmetic/canonical.cc | 21 ++++----------------- 1 file changed, 4 insertions(+), 17 deletions(-) diff --git a/src/arithmetic/canonical.cc b/src/arithmetic/canonical.cc index ae95b04a5305..f7ee6f45aecc 100644 --- a/src/arithmetic/canonical.cc +++ b/src/arithmetic/canonical.cc @@ -162,10 +162,8 @@ class Canonical::Internal : public IRMutator { return stmt; } Expr MutateExpr_(Expr expr) { - static const FMutateExpr& f = Internal::vtable_expr(); stack_.push_back(StackEntry()); - expr = (f.can_dispatch(expr) ? - f(expr, expr, this) : IRMutator::Mutate(expr)); + expr = IRMutator::Mutate(expr); // update result of parent automatically during pop if (stack_.size() > 1) { StackEntry& back = stack_[stack_.size() - 1]; @@ -200,7 +198,7 @@ class Canonical::Internal : public IRMutator { return (t.lanes() == 1 && (t.is_int() || t.is_uint())); } // Add - Expr Mutate_(const Add* op, const Expr& e) { + Expr Mutate_(const Add* op, const Expr& e) final { if (!EnableOpt(op->type)) { return Binary(op, e, this); } @@ -212,7 +210,7 @@ class Canonical::Internal : public IRMutator { return SumAdd(a, b, +1); } // Sub - Expr Mutate_(const Sub* op, const Expr& e) { + Expr Mutate_(const Sub* op, const Expr& e) final { if (!EnableOpt(op->type)) { return Binary(op, e, this); } @@ -224,7 +222,7 @@ class Canonical::Internal : public IRMutator { return SumAdd(a, b, -1); } // Mul - Expr Mutate_(const Mul* op, const Expr& e) { + Expr Mutate_(const Mul* op, const Expr& e) final { if (!EnableOpt(op->type)) { return Binary(op, e, this); } @@ -463,17 +461,6 @@ class Canonical::Internal : public IRMutator { using CInternal = Canonical::Internal; -#define DISPATCH_EXPR(OP) \ - set_dispatch([](const OP *op, const Expr& e, IRMutator* p) { \ - return static_cast(p)->Mutate_(op, e); }) - -TVM_STATIC_IR_FUNCTOR(CInternal, vtable_expr) -.DISPATCH_EXPR(Add) -.DISPATCH_EXPR(Sub) -.DISPATCH_EXPR(Mul) -.DISPATCH_EXPR(LT); - - Canonical::Canonical() : ptr_(std::make_shared()) {} From 6f70edb12181467efdf2081f36c68687e6cbc27d Mon Sep 17 00:00:00 2001 From: tqchen Date: Tue, 28 Feb 2017 15:31:27 -0800 Subject: [PATCH 6/7] [IRFunctor] Migrate vectorize --- src/pass/vectorize_loop.cc | 183 +++++++++++++++++++------------------ 1 file changed, 96 insertions(+), 87 deletions(-) diff --git a/src/pass/vectorize_loop.cc b/src/pass/vectorize_loop.cc index 109b1326f8cb..18f57217d1ad 100644 --- a/src/pass/vectorize_loop.cc +++ b/src/pass/vectorize_loop.cc @@ -69,11 +69,71 @@ class Vectorizer : public IRMutator { } // user mutate from parent. using IRMutator::Mutate; - // override mutate - Expr Mutate(Expr expr) final { - static const FMutateExpr& f = Vectorizer::vtable_expr(); - return (f.can_dispatch(expr) ? - f(expr, expr, this) : IRMutator::Mutate(expr)); + + Expr Mutate_(const Add* op, const Expr &e) final { + return AddSubVec(op, e); + } + Expr Mutate_(const Sub* op, const Expr &e) final { + return AddSubVec(op, e); + } + Expr Mutate_(const Mul* op, const Expr &e) final { + return BinaryVec(op, e); + } + Expr Mutate_(const Div* op, const Expr &e) final { + return BinaryVec(op, e); + } + Expr Mutate_(const Mod* op, const Expr &e) final { + return BinaryVec(op, e); + } + Expr Mutate_(const Min* op, const Expr &e) final { + return BinaryVec(op, e); + } + Expr Mutate_(const Max* op, const Expr &e) final { + return BinaryVec(op, e); + } + Expr Mutate_(const EQ* op, const Expr &e) final { + return BinaryVec(op, e); + } + Expr Mutate_(const NE* op, const Expr &e) final { + return BinaryVec(op, e); + } + Expr Mutate_(const LT* op, const Expr &e) final { + return BinaryVec(op, e); + } + Expr Mutate_(const GT* op, const Expr &e) final { + return BinaryVec(op, e); + } + Expr Mutate_(const GE* op, const Expr &e) final { + return BinaryVec(op, e); + } + Expr Mutate_(const And* op, const Expr &e) final { + return BinaryVec(op, e); + } + Expr Mutate_(const Or* op, const Expr &e) final { + return BinaryVec(op, e); + } + Expr Mutate_(const Select *op, const Expr& e) final { + Expr cond = this->Mutate(op->condition); + Expr t = this->Mutate(op->true_value); + Expr f = this->Mutate(op->false_value); + if (cond.same_as(op->condition) && + t.same_as(op->true_value) && + f.same_as(op->false_value)) { + return e; + } else { + int lanes = std::max(std::max( + cond.type().lanes(), + t.type().lanes()), f.type().lanes()); + return Select::make(cond, BroadcastTo(t, lanes), BroadcastTo(f, lanes)); + } + } + Expr Mutate_(const Cast *op, const Expr& e) final { + Expr value = this->Mutate(op->value); + if (value.same_as(op->value)) { + return e; + } else { + return Cast::make(op->type.with_lanes(value.type().lanes()), value); + } } // Variable Expr Mutate_(const Variable* v, const Expr& e) final { @@ -235,10 +295,6 @@ class Vectorizer : public IRMutator { stmt = Substitute(stmt, {{var_, idx}}); return For::make(idx, 0, var_lanes_, ForType::Serial, DeviceAPI::None, stmt); } - // The overloads for vectorize. - static FMutateExpr& vtable_expr() { // NOLINT(*) - static FMutateExpr inst; return inst; - } private: // variable to be replaced @@ -273,90 +329,43 @@ class Vectorizer : public IRMutator { if (!changed) return arr; return Array(new_arr); } -}; - -// binary vectorize -template -inline Expr BinaryVec(const T* op, const Expr& e, IRMutator* m) { - Expr a = m->Mutate(op->a); - Expr b = m->Mutate(op->b); - if (a.same_as(op->a) && - b.same_as(op->b)) { - return e; - } else { - int lanes = std::max(a.type().lanes(), b.type().lanes()); - return T::make(BroadcastTo(a, lanes), BroadcastTo(b, lanes)); - } -} - -template -inline Expr AddSubVec(const T* op, const Expr& e, IRMutator* m) { - Expr a = m->Mutate(op->a); - Expr b = m->Mutate(op->b); - if (a.same_as(op->a) && - b.same_as(op->b)) { - return e; - } else { - int lanes = std::max(a.type().lanes(), b.type().lanes()); - if (lanes != 1) { - const Ramp* b_ramp = b.as(); - const Ramp* a_ramp = a.as(); - if (a.type().lanes() == 1 && b_ramp) { - return Ramp::make( - arith::ComputeExpr(a, b_ramp->base), b_ramp->stride, b_ramp->lanes); - } - if (b.type().lanes() == 1 && a_ramp) { - return Ramp::make( - arith::ComputeExpr(a_ramp->base, b), a_ramp->stride, a_ramp->lanes); - } - } - return T::make(BroadcastTo(a, lanes), BroadcastTo(b, lanes)); - } -} - -TVM_STATIC_IR_FUNCTOR(Vectorizer, vtable_expr) -.set_dispatch(AddSubVec) -.set_dispatch(AddSubVec) -.set_dispatch(BinaryVec) -.set_dispatch
(BinaryVec
) -.set_dispatch(BinaryVec) -.set_dispatch(BinaryVec) -.set_dispatch(BinaryVec) -.set_dispatch(BinaryVec) -.set_dispatch(BinaryVec) -.set_dispatch(BinaryVec) -.set_dispatch(BinaryVec) -.set_dispatch(BinaryVec) -.set_dispatch(BinaryVec) -.set_dispatch(BinaryVec) -.set_dispatch(BinaryVec); - - -TVM_STATIC_IR_FUNCTOR(Vectorizer, vtable_expr) -.set_dispatch([](const Select *op, CodeGenStackVM* p) { - p->Push(op->true_value); - p->Push(op->false_value); - p->Push(op->condition); - p->PushOp(StackVM::SELECT); - }) -.set_dispatch([](const AssertStmt *op, CodeGenStackVM* p) { - if (op->message.as()) { - int sid = p->GetStrID(op->message.as()->value); - p->Push(op->condition); - p->PushOp(StackVM::ASSERT, sid); - } - }) -.set_dispatch([](const AttrStmt *op, CodeGenStackVM* p) { - p->Push(op->body); - }) -.set_dispatch([](const Let *op, CodeGenStackVM* p) { - p->Push(op->value); - int64_t vid = p->AllocVarID(op->var.get()); - p->PushOp(StackVM::STORE_HEAP, static_cast(vid)); - p->Push(op->body); - }) -.set_dispatch([](const Load *op, CodeGenStackVM* p) { - p->Push_(op); - }) -.set_dispatch([](const Store *op, CodeGenStackVM* p) { - p->Push_(op); - }) -.set_dispatch([](const Allocate *op, CodeGenStackVM* p) { - p->Push_(op); - }) -.set_dispatch([](const Call *op, CodeGenStackVM* p) { - p->Push_(op); - }); +void CodeGenStackVM::VisitExpr_(const StringImm *op) { + int sid = this->GetStrID(op->value); + this->PushOp(StackVM::PUSH_I64, sid); +} + +void CodeGenStackVM::VisitExpr_(const IntImm *op) { + CHECK(op->value >= std::numeric_limits::min() && + op->value <= std::numeric_limits::max()) + << "Int constant exceed bound"; + this->PushOp(StackVM::PUSH_I64, static_cast(op->value)); +} + +void CodeGenStackVM::VisitExpr_(const UIntImm *op) { + CHECK(op->value <= std::numeric_limits::max()) + << "Int constant exceed bound"; + this->PushOp(StackVM::PUSH_I64, static_cast(op->value)); +} + +void CodeGenStackVM::VisitExpr_(const FloatImm *op) { + LOG(FATAL) << "Float Imm is not supported"; +} + +void CodeGenStackVM::VisitExpr_(const Variable *op) { + int vid = this->GetVarID(op); + this->PushOp(StackVM::LOAD_HEAP, vid); +} + +void CodeGenStackVM::VisitExpr_(const Cast *op) { + this->Push(op->value); + PushCast(op->type, op->value.type()); +} + +void CodeGenStackVM::VisitExpr_(const Add *op) { + PushBinary(StackVM::ADD_I64, op->a, op->b); +} + +void CodeGenStackVM::VisitExpr_(const Sub *op) { + PushBinary(StackVM::SUB_I64, op->a, op->b); +} + +void CodeGenStackVM::VisitExpr_(const Mul *op) { + PushBinary(StackVM::MUL_I64, op->a, op->b); +} + +void CodeGenStackVM::VisitExpr_(const Div *op) { + PushBinary(StackVM::DIV_I64, op->a, op->b); +} + +void CodeGenStackVM::VisitExpr_(const Mod *op) { + PushBinary(StackVM::MOD_I64, op->a, op->b); +} + +void CodeGenStackVM::VisitExpr_(const Min *op) { + this->Push(op->a); + this->Push(op->b); + this->PushOp(StackVM::PUSH_VALUE, -1); + this->PushOp(StackVM::PUSH_VALUE, -1); + this->PushOp(StackVM::LT_I64); + this->PushOp(StackVM::SELECT); +} + +void CodeGenStackVM::VisitExpr_(const Max *op) { + this->Push(op->a); + this->Push(op->b); + this->PushOp(StackVM::PUSH_VALUE, 0); + this->PushOp(StackVM::PUSH_VALUE, -2); + this->PushOp(StackVM::LT_I64); + this->PushOp(StackVM::SELECT); +} + +void CodeGenStackVM::VisitExpr_(const EQ *op) { + PushBinary(StackVM::EQ_I64, op->a, op->b); +} + +void CodeGenStackVM::VisitExpr_(const LE *op) { + PushBinary(StackVM::LE_I64, op->a, op->b); +} + +void CodeGenStackVM::VisitExpr_(const NE *op) { + PushBinary(StackVM::EQ_I64, op->a, op->b); + this->PushOp(StackVM::NOT); +} + +void CodeGenStackVM::VisitExpr_(const LT *op) { + PushBinary(StackVM::LT_I64, op->a, op->b); +} + +void CodeGenStackVM::VisitExpr_(const GE *op) { + PushBinary(StackVM::LT_I64, op->a, op->b); + this->PushOp(StackVM::NOT); +} + +void CodeGenStackVM::VisitExpr_(const GT *op) { + PushBinary(StackVM::LE_I64, op->a, op->b); + this->PushOp(StackVM::NOT); +} + +void CodeGenStackVM::VisitExpr_(const And *op) { + this->Push(op->a); + int64_t pc_jump = this->GetPC(); + int64_t opr_index = this->PushOp(StackVM::RJUMP_IF_FALSE, 0); + this->PushOp(StackVM::POP); + this->Push(op->b); + int64_t diff = this->GetPC() - pc_jump; + this->SetOperand(opr_index, diff); +} + +void CodeGenStackVM::VisitExpr_(const Or *op) { + this->Push(op->a); + int64_t pc_jump = this->GetPC(); + int64_t opr_index = this->PushOp(StackVM::RJUMP_IF_TRUE, 0); + this->Push(op->b); + int64_t diff = this->GetPC() - pc_jump; + this->SetOperand(opr_index, diff); +} + +void CodeGenStackVM::VisitExpr_(const Not* op) { + this->PushOp(StackVM::NOT); +} + +void CodeGenStackVM::VisitStmt_(const ProducerConsumer *op) { + this->Push(op->body); +} + +void CodeGenStackVM::VisitStmt_(const For *op) { + CHECK(is_zero(op->min)); + int vid = this->AllocVarID(op->loop_var.get()); + this->PushOp(StackVM::PUSH_I64, 0); + int64_t loop_head = this->GetPC(); + this->PushOp(StackVM::STORE_HEAP, vid); + this->PushOp(StackVM::LOAD_HEAP, vid); + this->Push(op->extent); + this->PushOp(StackVM::LT_I64); + int64_t label_fjump = this->GetPC(); + int64_t foward_jump = this->PushOp(StackVM::RJUMP_IF_FALSE, 0); + this->PushOp(StackVM::POP); + this->Push(op->body); + this->PushOp(StackVM::LOAD_HEAP, vid); + this->PushOp(StackVM::PUSH_I64, 1); + this->PushOp(StackVM::ADD_I64); + int64_t label_bjump = this->GetPC(); + int64_t backward_jump = this->PushOp(StackVM::RJUMP, 0); + int64_t loop_end = this->GetPC(); + this->PushOp(StackVM::POP); + this->SetOperand(foward_jump, loop_end - label_fjump); + this->SetOperand(backward_jump, loop_head - label_bjump); +} + +void CodeGenStackVM::VisitStmt_(const Block *op) { + this->Push(op->first); + if (op->rest.defined()) this->Push(op->rest); +} + +void CodeGenStackVM::VisitStmt_(const Evaluate *op) { + if (is_const(op->value)) return; + this->Push(op->value); + this->PushOp(StackVM::POP); +} + +void CodeGenStackVM::VisitStmt_(const IfThenElse *op) { + this->Push(op->condition); + int64_t label_ejump = this->GetPC(); + int64_t else_jump = this->PushOp(StackVM::RJUMP_IF_FALSE, 0); + this->PushOp(StackVM::POP); + this->Push(op->then_case); + if (op->else_case.defined()) { + int64_t label_then_jump = this->GetPC(); + int64_t then_jump = this->PushOp(StackVM::RJUMP, 0); + int64_t else_begin = this->GetPC(); + this->SetOperand(else_jump, else_begin - label_ejump); + this->PushOp(StackVM::POP); + this->Push(op->else_case); + int64_t if_end = this->GetPC(); + this->SetOperand(then_jump, if_end - label_then_jump); + } else { + int64_t if_end = this->GetPC(); + this->SetOperand(else_jump, if_end - label_ejump); + this->PushOp(StackVM::POP); + } +} + +void CodeGenStackVM::VisitStmt_(const LetStmt *op) { + this->Push(op->value); + int64_t vid = this->AllocVarID(op->var.get()); + this->PushOp(StackVM::STORE_HEAP, static_cast(vid)); + this->Push(op->body); +} + +void CodeGenStackVM::VisitExpr_(const Ramp *op) { + LOG(FATAL) << "Ramp is not supported"; +} + +void CodeGenStackVM::VisitExpr_(const Broadcast *op) { + LOG(FATAL) << "Broadcast is not supported"; +} + +void CodeGenStackVM::VisitExpr_(const Select *op) { + this->Push(op->true_value); + this->Push(op->false_value); + this->Push(op->condition); + this->PushOp(StackVM::SELECT); +} + +void CodeGenStackVM::VisitStmt_(const AssertStmt *op) { + if (op->message.as()) { + int sid = this->GetStrID(op->message.as()->value); + this->Push(op->condition); + this->PushOp(StackVM::ASSERT, sid); + } +} + +void CodeGenStackVM::VisitStmt_(const AttrStmt *op) { + this->Push(op->body); +} + +void CodeGenStackVM::VisitExpr_(const Let *op) { + this->Push(op->value); + int64_t vid = this->AllocVarID(op->var.get()); + this->PushOp(StackVM::STORE_HEAP, static_cast(vid)); + this->Push(op->body); +} } // namespace codegen } // namespace tvm diff --git a/src/codegen/stack_vm/codegen_stack_vm.h b/src/codegen/stack_vm/codegen_stack_vm.h index 11919bd40955..5788aaa0fc2d 100644 --- a/src/codegen/stack_vm/codegen_stack_vm.h +++ b/src/codegen/stack_vm/codegen_stack_vm.h @@ -7,6 +7,7 @@ #define TVM_CODEGEN_STACK_VM_CODEGEN_STACK_VM_H_ #include +#include #include #include #include @@ -18,12 +19,15 @@ namespace tvm { namespace codegen { +using namespace ir; /*! * \brief A base class to generate a stack VM. * This module is used to generate host wrapper * into device function when only device JIT is available. */ -class CodeGenStackVM { +class CodeGenStackVM + : public ExprFunctor, + public StmtFunctor { public: /*! * \brief Generate a stack VM representing @@ -35,8 +39,10 @@ class CodeGenStackVM { StackVM Compile(LoweredFunc f); /*! \brief Push stmt to generate new code */ void Push(const Stmt& n); - /*! \brief Push expr to generate new code */ - void Push(const Expr& n); + /*! \brief Push expr to generate new code */ + void Push(const Expr& n) { + VisitExpr(n); + } /*! * \brief Push the opcode to the code. * \param opcode The code to be pushed. @@ -84,16 +90,53 @@ class CodeGenStackVM { * \return the heap index of the var. */ int GetVarID(const Variable* v) const; + // Push binary operator + void PushBinary(StackVM::OpCode op_int64, + const Expr& a, + const Expr& b); + // push cast; + void PushCast(Type dst, Type src); // overloadable functions - virtual void Push_(const ir::Load* op); - virtual void Push_(const ir::Store* op); - virtual void Push_(const ir::Allocate* op); - virtual void Push_(const ir::Call* op); - virtual void HandleUnknownCall(const ir::Call* op); - /*! \brief function to to print normal code */ - using FType = IRFunctor; - // vtable to print code - static FType& vtable(); // NOLINT(*) + // expression + void VisitExpr_(const Variable* op) final; + void VisitExpr_(const Load* op) final; + void VisitExpr_(const Let* op) final; + void VisitExpr_(const Call* op) final; + void VisitExpr_(const Add* op) final; + void VisitExpr_(const Sub* op) final; + void VisitExpr_(const Mul* op) final; + void VisitExpr_(const Div* op) final; + void VisitExpr_(const Mod* op) final; + void VisitExpr_(const Min* op) final; + void VisitExpr_(const Max* op) final; + void VisitExpr_(const EQ* op) final; + void VisitExpr_(const NE* op) final; + void VisitExpr_(const LT* op) final; + void VisitExpr_(const LE* op) final; + void VisitExpr_(const GT* op) final; + void VisitExpr_(const GE* op) final; + void VisitExpr_(const And* op) final; + void VisitExpr_(const Or* op) final; + void VisitExpr_(const Cast* op) final; + void VisitExpr_(const Not* op) final; + void VisitExpr_(const Select* op) final; + void VisitExpr_(const Ramp* op) final; + void VisitExpr_(const Broadcast* op) final; + void VisitExpr_(const IntImm* op) final; + void VisitExpr_(const UIntImm* op) final; + void VisitExpr_(const FloatImm* op) final; + void VisitExpr_(const StringImm* op) final; + // statment + void VisitStmt_(const LetStmt* op) final; + void VisitStmt_(const Store* op) final; + void VisitStmt_(const For* op) final; + void VisitStmt_(const IfThenElse* op) final; + void VisitStmt_(const Allocate* op) final; + void VisitStmt_(const AttrStmt* op) final; + void VisitStmt_(const AssertStmt* op) final; + void VisitStmt_(const Evaluate* op) final; + void VisitStmt_(const Block* op) final; + void VisitStmt_(const ProducerConsumer* op) final; private: bool debug_{false};