@@ -85,10 +85,10 @@ using ADValue = std::shared_ptr<ADValueNode>;
 
 /*! \brief AD over a program which generates a tensor output. */
 struct ADTensor : ADValueNode {
-  Expr foward;
+  Expr forward;
   mutable Expr reverse;  // must be a variable to avoid duplication
-  ADTensor(LetList* ll, const Expr& foward) :
-    foward(ll->Push(foward)), reverse(ll->Push(ZeroLike(this->foward))) { }
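+  // both fields are let-bound so downstream uses share a single binding;
+  // reverse starts as zeros and is accumulated into during the backward sweep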
+  ADTensor(LetList* ll, const Expr& forward) :
+    forward(ll->Push(forward)), reverse(ll->Push(ZerosLike(this->forward))) { }
 };
 
 /*! \brief A staged representation of the program, we reflect
@@ -105,14 +105,14 @@ struct ADFunction : ADValueNode {
     func(func) { }
 };
 
-struct ReverseAD : ExprFunctor<ADValue(const Expr&)> {
+struct FirstOrderReverseAD : ExprFunctor<ADValue(const Expr&)> {
   const OpMap<FPrimalGradient> rev_map = Op::GetAttr<FPrimalGradient>("FPrimalGradient");
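+  // closures recorded while staging the forward computation; replayed in
+  // reverse order to emit the backward pass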
   std::vector<std::function<void(LetList* ll)>> backprop_actions;
   // we assume no closure so no need for lexical scoping
   std::unordered_map<Var, ADValue, NodeHash, NodeEqual> env;
   LetList* ll;
 
-  ReverseAD(LetList* ll) : ll(ll) { }
+  FirstOrderReverseAD(LetList* ll) : ll(ll) { }
 
   ADValue VisitExpr_(const OpNode* op) final {
     Op op_ref = GetRef<Op>(op);
@@ -121,21 +121,22 @@ struct ReverseAD : ExprFunctor<ADValue(const Expr&)> {
     return std::make_shared<ADFunction>([this, op_ref](const std::vector<ADValue>& args,
                                                        const Attrs& attrs,
                                                        const tvm::Array<Type>& type_args) {
-      std::vector<Expr> call_args;
-      for (const ADValue& adval : args) {
-        call_args.push_back(adval->get<ADTensor>().foward);
+        std::vector<Expr> call_args;
+        for (const ADValue& adval : args) {
+          call_args.push_back(adval->get<ADTensor>().forward);
+        }
+        auto orig = CallNode::make(op_ref, call_args, attrs, type_args);
+        auto ret = std::make_shared<ADTensor>(ll, orig);
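+        // defer the backward step for this call: when replayed, look up the
+        // op's FPrimalGradient and accumulate its outputs into each
+        // argument's reverse value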
+        backprop_actions.push_back([this, args, orig, ret, op_ref](LetList* ll) {
+          tvm::Array<Expr> rev = rev_map[op_ref](orig, ret->reverse);
+          CHECK(args.size() == rev.size());
+          for (size_t i = 0; i < args.size(); ++i) {
+            args[i]->get<ADTensor>().reverse =
+              ll->Push(Add(args[i]->get<ADTensor>().reverse, rev[i]));
           }
-      auto orig = CallNode::make(op_ref, call_args, attrs, type_args);
-      auto ret = std::make_shared<ADTensor>(ll, orig);
-      backprop_actions.push_back([this, args, orig, ret, op_ref](LetList* ll) {
-        tvm::Array<Expr> rev = rev_map[op_ref](orig, ret->reverse);
-        for (size_t i = 0; i < args.size(); ++i) {
-          args[i]->get<ADTensor>().reverse =
-            ll->Push(Add(args[i]->get<ADTensor>().reverse, rev[i]));
-        }
-      });
-      return ret;
         });
+      return ret;
+    });
   }
 
   ADValue VisitExpr_(const ConstantNode* op) final {
@@ -172,6 +173,23 @@ struct ReverseAD : ExprFunctor<ADValue(const Expr&)> {
   }
 };
 
+Type GradRetType(const Function& f) {
+  // if type annotations are provided, we will construct a ret type;
+  // otherwise, leave it to be inferred
+  if (!f->ret_type.defined()) {
+    return Type();
+  }
+  std::vector<Type> vt;
+  for (const auto& p : f->params) {
+    if (!p->type_annotation.defined()) {
+      return Type();
+    }
+    vt.push_back(p->type_annotation);
+  }
+
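+  // the gradient of f : (T1, ..., Tn) -> R is typed
+  // (T1, ..., Tn) -> (R, (T1, ..., Tn))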
+  return TupleTypeNode::make({f->ret_type, TupleTypeNode::make(vt)});
+}
+
 Expr FirstOrderGradient(const Expr& re, const Module& mod) {
   // Currently we first remove any global functions for the first
   // order case.
@@ -182,7 +200,7 @@ Expr FirstOrderGradient(const Expr& re, const Module& mod) {
 
   // We will then build a sequence of lets which implement reverse mode.
   Expr body = LetList::With([&](LetList* ll) {
-    ReverseAD reverse_ad(ll);
+    FirstOrderReverseAD reverse_ad(ll);
     ADValue rev = reverse_ad(e);
     std::vector<ADValue> args;
     for (const auto& p : f->params) {
@@ -191,46 +209,131 @@ Expr FirstOrderGradient(const Expr& re, const Module& mod) {
     auto c = rev->get<ADFunction>().func(args, Attrs(), {});
     const auto& res = c->get<ADTensor>();
     Expr grad = LetList::With([&](LetList* ll) {
-      res.reverse = OneLike(res.foward);
-      for (auto it = reverse_ad.backprop_actions.rbegin();
-           it != reverse_ad.backprop_actions.rend();
-           ++it) {
-        (*it)(ll);
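+      // seed the sweep: the adjoint of the output with respect to itself is ones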
+      res.reverse = OnesLike(res.forward);
+      for (auto it = reverse_ad.backprop_actions.rbegin();
+           it != reverse_ad.backprop_actions.rend();
+           ++it) {
+        (*it)(ll);
+      }
+      std::vector<Expr> grad_res;
+      for (const auto& a : args) {
+        grad_res.push_back(a->get<ADTensor>().reverse);
+      }
+      return TupleNode::make(grad_res);
+    });
+    return Pair(res.forward, grad);
+  });
+
+  return FunctionNode::make(f->params, body, GradRetType(GetRef<Function>(f)), {});
+}
+
+TVM_REGISTER_API("relay._ir_pass.first_order_gradient")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+    CHECK_EQ(args.size(), 2);
+    *ret = FirstOrderGradient(args[0], args[1]);
+  });
+
+struct ReverseADType : TypeMutator {
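+  // every tensor type t becomes (t, ref(t)): the forward value paired with a
+  // mutable reference that accumulates its adjoint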
+  Type VisitType_(const TensorTypeNode* ttn) final {
+    Type t = GetRef<Type>(ttn);
+    return TupleTypeNode::make({t, RefTypeNode::make(t)});
+  }
+};
+
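+// rewrite each tensor-valued expression into a (value, adjoint reference) pair,
+// threading a growing backprop closure through the reference bp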
+struct ReverseAD : ExprMutator {
+  Var bp;
+  const OpMap<FPrimalGradient> rev_map = Op::GetAttr<FPrimalGradient>("FPrimalGradient");
+
+  ReverseAD(const Var& bp) : bp(bp) { }
+
+  Expr VisitExpr_(const OpNode* op) final {
+    LOG(FATAL) << "op should only be inside call";
+    throw;
+  }
+
+  Expr VisitExpr_(const CallNode* op) final {
+    if (const OpNode* op_node = op->op.as<OpNode>()) {
+      Op op_ref = GetRef<Op>(op_node);
+      CHECK(rev_map.count(op_ref))
+        << op_node->name << " does not have reverse mode defined";
+      return LetList::With([&](LetList* ll) {
+        std::vector<Var> args;
+        for (const auto& arg : op->args) {
+          args.push_back(ll->Push(VisitExpr(arg)));
         }
-      std::vector<Expr> grad_res;
-      for (const auto& a : args) {
-        grad_res.push_back(a->get<ADTensor>().reverse);
+        std::vector<Expr> orig_args;
+        for (const auto& arg : args) {
+          orig_args.push_back(GetField(VisitExpr(arg), 0));
         }
-      return TupleNode::make(grad_res);
+        Expr orig = CallNode::make(op->op, orig_args, op->attrs, op->type_args);
+        Var orig_var = ll->Push(orig);
+        auto ref = ll->Push(RefCreateNode::make(ZerosLike(orig_var)));
+        auto bpv = ll->Push(RefReadNode::make(bp));
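+        // chain the backprop closure: the new one reads this call's adjoint
+        // from ref, writes it back into the arguments' adjoint references via
+        // the op's FPrimalGradient, then invokes the closure installed before it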
+        Expr nbp = FunctionNode::make(
+          {},
+          LetList::With([&](LetList* ll) {
+            tvm::Array<Expr> rev = rev_map[op_ref](orig, ll->Push(RefReadNode::make(ref)));
+            CHECK(args.size() == rev.size());
+            for (size_t i = 0; i < args.size(); ++i) {
+              ll->Push(RefWriteNode::make(GetField(args[i], 1),
+                                          Add(ll->Push(RefReadNode::make(GetField(args[i], 1))),
+                                              rev[i])));
+            }
+            return CallNode::make(bpv, {});
+          }),
+          TupleTypeNode::make({}),
+          {});
+        ll->Push(RefWriteNode::make(bp, nbp));
+        return Pair(orig_var, ref);
       });
-    return Pair(res.foward, grad);
-  });
-
-  // if type annotations are provided, we will construct a ret type;
-  // otherwise, leave it to be inferred
-  Type ret_type = Type();
-  std::vector<Type> vt;
-  bool missing = !f->ret_type.defined();
-  for (const auto& p : f->params) {
-    if (missing || !p->type_annotation.defined()) {
-      missing = true;
-      break;
     }
-    vt.push_back(p->type_annotation);
+    return ExprMutator::VisitExpr_(op);
+  }
+
+  Expr VisitExpr_(const ConstantNode* op) final {
+    Expr e = GetRef<Expr>(op);
+    return Pair(e, RefCreateNode::make(ZerosLike(e)));
   }
 
-  if (!missing) {
-    ret_type = TupleTypeNode::make({f->ret_type, TupleTypeNode::make(vt)});
+  Type VisitType(const Type& t) final {
+    return t.defined() ? ReverseADType()(t) : t;
   }
+};
 
-  return FunctionNode::make(f->params, body, ret_type, {});
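+// a reference to a do-nothing backprop closure; the base case onto which
+// ReverseAD chains per-call closures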
+Expr BPEmpty() {
+  Expr unitF = FunctionNode::make({}, TupleNode::make({}), TupleTypeNode::make({}), {});
+  return RefCreateNode::make(unitF);
 }
 
-TVM_REGISTER_API("relay._ir_pass.first_order_gradient")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-    CHECK_EQ(args.size(), 2);
-    *ret = FirstOrderGradient(args[0], args[1]);
-  });
+Expr Gradient(const Expr& re, const Module& mod) {
+  auto e = DeGlobal(mod, re);
+  auto f = e.as<FunctionNode>();
+  CHECK(f) << "input needs to be a function";
+  CHECK(f->type_params.size() == 0) << "no polymorphism supported for now";
+  Expr body = LetList::With([&](LetList* ll) {
+    Var bp = ll->Push(BPEmpty());
+    Expr rev = ReverseAD(bp)(e);
+    std::vector<Expr> args;
+    for (const auto& p : f->params) {
+      args.push_back(ll->Push(Pair(p, RefCreateNode::make(ZerosLike(p)))));
+    }
+    auto c = ll->Push(CallNode::make(rev, args));
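+    // seed the output adjoint with ones, then run the chained backprop closure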
+    ll->Push(RefWriteNode::make(GetField(c, 1), OnesLike(GetField(c, 0))));
+    ll->Push(CallNode::make(RefReadNode::make(bp), {}));
+    std::vector<Expr> ret;
+    for (const auto& a : args) {
+      ret.push_back(RefReadNode::make(GetField(a, 1)));
+    }
+    return Pair(GetField(c, 0), TupleNode::make(ret));
+  });
+  return FunctionNode::make(f->params, body, GradRetType(GetRef<Function>(f)), {});
+}
+
+TVM_REGISTER_API("relay._ir_pass.gradient")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+    CHECK_EQ(args.size(), 2);
+    *ret = Gradient(args[0], args[1]);
+  });
 
 }  // namespace relay
 }  // namespace tvm