[Relay] Higher order reverse mode automatic differentiation that works with control flow #2496
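This PR splits the existing first-order pass into `FirstOrderReverseAD`/`FirstOrderGradient` and adds a higher-order `ReverseAD` pass, exposed as `relay._ir_pass.gradient`, which handles closures and control flow by threading gradients through references. As a rough orientation before the diff, here is a hedged usage sketch from Python; it assumes the usual `relay.ir_pass` wrappers of that era (`infer_type`, `gradient`), which are not part of the diff shown here:

```python
import tvm
from tvm import relay

# Build a small function y = x * x with an annotated parameter type.
x = relay.var("x", shape=(3,), dtype="float32")
func = relay.Function([x], relay.multiply(x, x))

# Infer types, then take the gradient. Per GradRetType below, the result
# computes a pair: (original output, (d output / d x,)).
func = relay.ir_pass.infer_type(func)
grad_func = relay.ir_pass.gradient(func)  # assumed wrapper over relay._ir_pass.gradient
print(grad_func)
```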
@@ -85,10 +85,10 @@ using ADValue = std::shared_ptr<ADValueNode>;

 /*! \brief AD over a program which generates a tensor output. */
 struct ADTensor : ADValueNode {
-  Expr foward;
+  Expr forward;
   mutable Expr reverse;  // must be a variable to avoid duplication
-  ADTensor(LetList* ll, const Expr& foward) :
-    foward(ll->Push(foward)), reverse(ll->Push(ZeroLike(this->foward))) { }
+  ADTensor(LetList* ll, const Expr& forward) :
+    forward(ll->Push(forward)), reverse(ll->Push(ZerosLike(this->forward))) { }
 };

 /*! \brief A staged representation of the program, we reflect

Review comment (Member), on the `ADTensor` doc comment: What if the program generates a tuple of tensors as output?

@@ -105,14 +105,14 @@ struct ADFunction : ADValueNode {
     func(func) { }
 };

-struct ReverseAD : ExprFunctor<ADValue(const Expr &)> {
+struct FirstOrderReverseAD : ExprFunctor<ADValue(const Expr &)> {
   const OpMap<FPrimalGradient> rev_map = Op::GetAttr<FPrimalGradient>("FPrimalGradient");
   std::vector<std::function<void(LetList* ll)>> backprop_actions;
   // we assume no closure so no need for lexical scoping
   std::unordered_map<Var, ADValue, NodeHash, NodeEqual> env;
   LetList* ll;

-  ReverseAD(LetList* ll) : ll(ll) { }
+  FirstOrderReverseAD(LetList* ll) : ll(ll) { }

   ADValue VisitExpr_(const OpNode* op) final {
     Op op_ref = GetRef<Op>(op);

@@ -121,21 +121,22 @@ struct ReverseAD : ExprFunctor<ADValue(const Expr &)> {
     return std::make_shared<ADFunction>([this, op_ref](const std::vector<ADValue>& args,
                                                        const Attrs& attrs,
                                                        const tvm::Array<Type>& type_args) {
-      std::vector<Expr> call_args;
-      for (const ADValue& adval : args) {
-        call_args.push_back(adval->get<ADTensor>().foward);
-      }
-      auto orig = CallNode::make(op_ref, call_args, attrs, type_args);
-      auto ret = std::make_shared<ADTensor>(ll, orig);
-      backprop_actions.push_back([this, args, orig, ret, op_ref](LetList* ll) {
-        tvm::Array<Expr> rev = rev_map[op_ref](orig, ret->reverse);
-        CHECK(args.size() == rev.size());
-        for (size_t i = 0; i < args.size(); ++i) {
-          args[i]->get<ADTensor>().reverse =
-            ll->Push(Add(args[i]->get<ADTensor>().reverse, rev[i]));
-        }
-      });
-      return ret;
-    });
+        std::vector<Expr> call_args;
+        for (const ADValue& adval : args) {
+          call_args.push_back(adval->get<ADTensor>().forward);
+        }
+        auto orig = CallNode::make(op_ref, call_args, attrs, type_args);
+        auto ret = std::make_shared<ADTensor>(ll, orig);
+        backprop_actions.push_back([this, args, orig, ret, op_ref](LetList* ll) {
+            tvm::Array<Expr> rev = rev_map[op_ref](orig, ret->reverse);
+            for (size_t i = 0; i < args.size(); ++i) {
+              args[i]->get<ADTensor>().reverse =
+                ll->Push(Add(args[i]->get<ADTensor>().reverse, rev[i]));
+            }
+          });
+        return ret;
+      });
   }

   ADValue VisitExpr_(const ConstantNode* op) final {

Review thread on `auto orig = CallNode::make(op_ref, call_args, attrs, type_args);`:

Reviewer (@sgrechanik-h): Is it possible to use the real original node instead of a reconstruction? Reconstructing a node may lead to losing some information, e.g. the inferred type.

Author (@MarisaKirisame): Maybe, but it would require a big change in code structure. If such a case comes up I will do it.

Reviewer: I need […]

Author: @sgrechanik-h Can I just rerun type inference? Right now every pass destroys `checked_type_` and rebuilds it from type inference.

Reviewer: @MarisaKirisame Not sure what you mean, but rerunning type inference sounds like a bit of an overkill, and I'm not sure it can be done before calling the […] (Also, currently I think that in my particular case the proper solution would be to fix the signature of […])

Author: @sgrechanik-h All passes (FuseOps, AD, ANF, GNF, DeadCodeElimination, FoldScaleAxis) remove the type annotation and rerun it, AFAIK. I am not sure why it is an AD-specific issue.

Reviewer: @MarisaKirisame I think some passes may benefit from using type information, and, of course, they should use it before erasing it (or recreating the node). I don't think […]

Author: My other passes use type info too. But we just rerun type inference, and we are encoding that (rerunning type inference) into the pass manager too.
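For intuition, `FirstOrderReverseAD` is a classic tape: an `ADTensor` pairs the forward expression with a zero-initialized `reverse` accumulator, and each visited call pushes a closure onto `backprop_actions` that scatters the output gradient into its arguments; the closures are later replayed in reverse order. A minimal Python sketch of that mechanism on plain floats (an analogy to the technique, not the Relay code; all names here are made up):

```python
class ADTensor:
    """Forward value plus a gradient accumulator (mirrors forward/reverse)."""
    def __init__(self, forward):
        self.forward = forward
        self.reverse = 0.0  # cf. ZerosLike

backprop_actions = []  # the tape of reverse-mode closures

def mul(a, b):
    out = ADTensor(a.forward * b.forward)
    def backprop():
        # the "primal gradient" of multiply, accumulated into the arguments
        a.reverse += out.reverse * b.forward
        b.reverse += out.reverse * a.forward
    backprop_actions.append(backprop)
    return out

x = ADTensor(3.0)
y = mul(x, x)                  # the forward pass records backprop actions
y.reverse = 1.0                # seed the output gradient, cf. OnesLike
for action in reversed(backprop_actions):
    action()                   # replay the tape in reverse order
print(y.forward, x.reverse)    # 9.0 and dy/dx = 6.0
```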
@@ -172,6 +173,23 @@ struct ReverseAD : ExprFunctor<ADValue(const Expr &)> {
   }
 };

+Type GradRetType(const Function& f) {
+  // if type annotations are provided, we will construct a ret type;
+  // otherwise, leave it to be inferred
+  if (!f->ret_type.defined()) {
+    return Type();
+  }
+  std::vector<Type> vt;
+  for (const auto& p : f->params) {
+    if (!p->type_annotation.defined()) {
+      return Type();
+    }
+    vt.push_back(p->type_annotation);
+  }
+
+  return TupleTypeNode::make({f->ret_type, TupleTypeNode::make(vt)});
+}
+
 Expr FirstOrderGradient(const Expr& re, const Module& mod) {
   // Currently we first remove any global functions for the first
   // order case.
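To illustrate `GradRetType`: when the input function is fully annotated, the gradient wrapper is typed as a pair of the original return type and a tuple of the parameter types. A hypothetical sketch of that type for a one-parameter function, using the Python type constructors that mirror `TupleTypeNode::make` (assumed API, not part of this diff):

```python
from tvm import relay

# For f : fn (Tensor[(3,), float32]) -> Tensor[(3,), float32], GradRetType
# produces the type of (original result, (gradient w.r.t. each parameter,)).
t = relay.TensorType((3,), "float32")
grad_ret_type = relay.TupleType([t, relay.TupleType([t])])
print(grad_ret_type)
```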
@@ -182,7 +200,7 @@ Expr FirstOrderGradient(const Expr& re, const Module& mod) {

   // We will then build a sequence of lets which implement reverse mode.
   Expr body = LetList::With([&](LetList* ll) {
-    ReverseAD reverse_ad(ll);
+    FirstOrderReverseAD reverse_ad(ll);
     ADValue rev = reverse_ad(e);
     std::vector<ADValue> args;
     for (const auto& p : f->params) {

@@ -191,46 +209,131 @@ Expr FirstOrderGradient(const Expr& re, const Module& mod) {
     auto c = rev->get<ADFunction>().func(args, Attrs(), {});
     const auto& res = c->get<ADTensor>();
     Expr grad = LetList::With([&](LetList* ll) {
-      res.reverse = OneLike(res.foward);
-      for (auto it = reverse_ad.backprop_actions.rbegin();
-           it != reverse_ad.backprop_actions.rend();
-           ++it) {
-        (*it)(ll);
-      }
-      std::vector<Expr> grad_res;
-      for (const auto& a : args) {
-        grad_res.push_back(a->get<ADTensor>().reverse);
-      }
-      return TupleNode::make(grad_res);
-    });
-    return Pair(res.foward, grad);
-  });
-
-  // if type annotations are provided, we will construct a ret type;
-  // otherwise, leave it to be inferred
-  Type ret_type = Type();
-  std::vector<Type> vt;
-  bool missing = !f->ret_type.defined();
-  for (const auto& p : f->params) {
-    if (missing || !p->type_annotation.defined()) {
-      missing = true;
-      break;
-    }
-    vt.push_back(p->type_annotation);
-  }
-
-  if (!missing) {
-    ret_type = TupleTypeNode::make({f->ret_type, TupleTypeNode::make(vt)});
-  }
-
-  return FunctionNode::make(f->params, body, ret_type, {});
+        res.reverse = OnesLike(res.forward);
+        for (auto it = reverse_ad.backprop_actions.rbegin();
+             it != reverse_ad.backprop_actions.rend();
+             ++it) {
+          (*it)(ll);
+        }
+        std::vector<Expr> grad_res;
+        for (const auto& a : args) {
+          grad_res.push_back(a->get<ADTensor>().reverse);
+        }
+        return TupleNode::make(grad_res);
+      });
+    return Pair(res.forward, grad);
+  });
+
+  return FunctionNode::make(f->params, body, GradRetType(GetRef<Function>(f)), {});
 }

 TVM_REGISTER_API("relay._ir_pass.first_order_gradient")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
     CHECK_EQ(args.size(), 2);
     *ret = FirstOrderGradient(args[0], args[1]);
   });

+struct ReverseADType : TypeMutator {
+  Type VisitType_(const TensorTypeNode* ttn) final {
+    Type t = GetRef<Type>(ttn);
+    return TupleTypeNode::make({t, RefTypeNode::make(t)});
+  }
+};
+
+struct ReverseAD : ExprMutator {
+  Var bp;
+  const OpMap<FPrimalGradient> rev_map = Op::GetAttr<FPrimalGradient>("FPrimalGradient");
+
+  ReverseAD(const Var& bp) : bp(bp) { }
+
+  Expr VisitExpr_(const OpNode* op) final {
+    LOG(FATAL) << "op should only be inside call";
+    throw;
+  }
+
+  Expr VisitExpr_(const CallNode* op) final {
+    if (const OpNode* op_node = op->op.as<OpNode>()) {
+      Op op_ref = GetRef<Op>(op_node);
+      CHECK(rev_map.count(op_ref))
+        << op_node->name << " does not have reverse mode defined";
+      return LetList::With([&](LetList* ll) {
+          std::vector<Var> args;
+          for (const auto& arg : op->args) {
+            args.push_back(ll->Push(VisitExpr(arg)));
+          }
+          std::vector<Expr> orig_args;
+          for (const auto& arg : args) {
+            orig_args.push_back(GetField(VisitExpr(arg), 0));
+          }
+          Expr orig = CallNode::make(op->op, orig_args, op->attrs, op->type_args);
+          Var orig_var = ll->Push(orig);
+          auto ref = ll->Push(RefCreateNode::make(ZerosLike(orig_var)));
+          auto bpv = ll->Push(RefReadNode::make(bp));
+          Expr nbp = FunctionNode::make(
+            {},
+            LetList::With([&](LetList* ll) {
+              tvm::Array<Expr> rev = rev_map[op_ref](orig, ll->Push(RefReadNode::make(ref)));
+              CHECK(args.size() == rev.size());
+              for (size_t i = 0; i < args.size(); ++i) {
+                ll->Push(RefWriteNode::make(GetField(args[i], 1),
+                                            Add(ll->Push(RefReadNode::make(GetField(args[i], 1))),
+                                                rev[i])));
+              }
+              return CallNode::make(bpv, {});
+            }),
+            TupleTypeNode::make({}),
+            {});
+          ll->Push(RefWriteNode::make(bp, nbp));
+          return Pair(orig_var, ref);
+        });
+    }
+    return ExprMutator::VisitExpr_(op);
+  }
+
+  Expr VisitExpr_(const ConstantNode* op) final {
+    Expr e = GetRef<Expr>(op);
+    return Pair(e, RefCreateNode::make(ZerosLike(e)));
+  }
+
+  Type VisitType(const Type& t) final {
+    return t.defined() ? ReverseADType()(t) : t;
+  }
+};
+
+Expr BPEmpty() {
+  Expr unitF = FunctionNode::make({}, TupleNode::make({}), TupleTypeNode::make({}), {});
+  return RefCreateNode::make(unitF);
+}
+
+Expr Gradient(const Expr& re, const Module& mod) {
+  auto e = DeGlobal(mod, re);
+  auto f = e.as<FunctionNode>();
+  CHECK(f) << "input need to be a function";
+  CHECK(f->type_params.size() == 0) << "no polymorphism supported for now";
+  Expr body = LetList::With([&](LetList* ll) {
+    Var bp = ll->Push(BPEmpty());
+    Expr rev = ReverseAD(bp)(e);
+    std::vector<Expr> args;
+    for (const auto& p : f->params) {
+      args.push_back(ll->Push(Pair(p, RefCreateNode::make(ZerosLike(p)))));
+    }
+    auto c = ll->Push(CallNode::make(rev, args));
+    ll->Push(RefWriteNode::make(GetField(c, 1), OnesLike(GetField(c, 0))));
+    ll->Push(CallNode::make(RefReadNode::make(bp), {}));
+    std::vector<Expr> ret;
+    for (const auto& a : args) {
+      ret.push_back(RefReadNode::make(GetField(a, 1)));
+    }
+    return Pair(GetField(c, 0), TupleNode::make(ret));
+  });
+  return FunctionNode::make(f->params, body, GradRetType(GetRef<Function>(f)), {});
+}
+
+TVM_REGISTER_API("relay._ir_pass.gradient")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+    CHECK_EQ(args.size(), 2);
+    *ret = Gradient(args[0], args[1]);
+  });
+
 } // namespace relay
 } // namespace tvm
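The higher-order `ReverseAD` works differently from the tape: every tensor-typed value is rewritten into a pair of (forward value, reference cell holding its gradient), per `ReverseADType`, and a single backpropagator reference `bp` (initialized by `BPEmpty`) is repeatedly overwritten so that each new closure scatters its output gradient into its argument cells and then calls the previously installed backpropagator. Because this is a purely local expression rewrite, it survives closures and control flow. A small Python sketch of the same idea, with plain objects standing in for Relay references (illustrative only; none of these names come from the diff):

```python
class Ref:
    """Stand-in for a Relay reference cell (RefCreate/RefRead/RefWrite)."""
    def __init__(self, value):
        self.value = value

# The backpropagator reference starts as a no-op, cf. BPEmpty().
bp = Ref(lambda: None)

def lift(x):
    # A value becomes (forward value, gradient cell), cf. ReverseADType.
    return (x, Ref(0.0))

def mul(a, b):
    (af, ag), (bf, bg) = a, b
    out = (af * bf, Ref(0.0))
    prev = bp.value
    def new_bp():
        # Scatter the output gradient into the argument cells, then call
        # the previously installed backpropagator (reverse program order).
        ag.value += out[1].value * bf
        bg.value += out[1].value * af
        prev()
    bp.value = new_bp
    return out

x = lift(3.0)
y = mul(x, x)
y[1].value = 1.0   # seed the output gradient, cf. OnesLike in Gradient()
bp.value()         # run the chained backpropagators
print(y[0], x[1].value)  # 9.0 and dy/dx = 6.0
```

Seeding the output cell with ones and then invoking the accumulated backpropagator mirrors what the `Gradient` function above emits after calling the rewritten function.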