
Commit 8ce0930

MarisaKirisame authored and bwasti committed
[Relay] Higher order reverse mode automatic differentiation that works with control flow (apache#2496)
add test
remove dead code
stash
do it
add more test
1 parent 156114f commit 8ce0930

File tree: 6 files changed, +243 / -56 lines


python/tvm/relay/ir_pass.py

Lines changed: 17 additions & 4 deletions
@@ -530,9 +530,11 @@ def to_graph_normal_form(expr):
     return _ir_pass.to_graph_normal_form(expr)
 
 
-def gradient(expr, mod=None):
+def gradient(expr, mod=None, mode='higher_order'):
     """
-    Transform a function to return original result paired with gradient of input.
+    Transform the input function,
+    returning a function that calculates the original result,
+    paired with the gradient of the input.
 
     Parameters
     ----------
@@ -541,12 +543,23 @@ def gradient(expr, mod=None):
 
     mod : Optional[tvm.relay.Module]
 
+    mode : Optional[String]
+        The mode of the automatic differentiation algorithm.
+        'first_order' only works on first order code, but will not produce references or closures.
+        'higher_order' works on all code, using references and closures.
+
     Returns
     -------
     expr : tvm.relay.Expr
-        The output expression.
+        The transformed expression.
     """
-    return _ir_pass.first_order_gradient(expr, mod)
+    if mode == 'first_order':
+        return _ir_pass.first_order_gradient(expr, mod)
+    elif mode == 'higher_order':
+        return _ir_pass.gradient(expr, mod)
+    else:
+        raise Exception('unknown mode')
 
 
 def get_total_mac_number(expr):
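For orientation, here is a minimal usage sketch of the new mode parameter. It is not part of this diff; it assumes a build containing this commit and a gradient registration for add (as exercised by the tests), with every other name being ordinary relay API:

    from tvm import relay
    from tvm.relay.ir_pass import gradient, infer_type

    # f(x, y) = x + y, fully annotated so the return type below is explicit.
    shape = (3, 4)
    x = relay.var("x", shape=shape)
    y = relay.var("y", shape=shape)
    f = relay.Function([x, y], x + y)

    # 'first_order' uses the tape-based pass; 'higher_order' (the default)
    # uses the new reference/closure-based pass and also handles control flow.
    grad_first = infer_type(gradient(f, mode='first_order'))
    grad_higher = infer_type(gradient(f))

    # Both return a function computing (f(x, y), (df/dx, df/dy)).
    print(grad_higher.checked_type)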

src/relay/pass/fuse_ops.cc

Lines changed: 1 addition & 0 deletions
@@ -225,6 +225,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
     }
 
     node->pattern = op_pattern;
+    this->Update(call->op, nullptr, kOpaque);
     const auto* rtype = call->checked_type().as<TensorTypeNode>();
     // pass the message back to all the children it references.
     for (size_t i = 0; i < call->args.size(); ++i) {

src/relay/pass/gradient.cc

Lines changed: 152 additions & 49 deletions
@@ -85,10 +85,10 @@ using ADValue = std::shared_ptr<ADValueNode>;
 
 /*! \brief AD over a program which generates a tensor output. */
 struct ADTensor : ADValueNode {
-  Expr foward;
+  Expr forward;
   mutable Expr reverse; // must be a variable to avoid duplication
-  ADTensor(LetList* ll, const Expr& foward) :
-    foward(ll->Push(foward)), reverse(ll->Push(ZeroLike(this->foward))) { }
+  ADTensor(LetList* ll, const Expr& forward) :
+    forward(ll->Push(forward)), reverse(ll->Push(ZerosLike(this->forward))) { }
 };
 
 /*! \brief A staged representation of the program, we reflect
@@ -105,14 +105,14 @@ struct ADFunction : ADValueNode {
     func(func) { }
 };
 
-struct ReverseAD : ExprFunctor<ADValue(const Expr &)> {
+struct FirstOrderReverseAD : ExprFunctor<ADValue(const Expr &)> {
   const OpMap<FPrimalGradient> rev_map = Op::GetAttr<FPrimalGradient>("FPrimalGradient");
   std::vector<std::function<void(LetList* ll)>> backprop_actions;
   // we assume no closure so no need for lexical scoping
   std::unordered_map<Var, ADValue, NodeHash, NodeEqual> env;
   LetList* ll;
 
-  ReverseAD(LetList* ll) : ll(ll) { }
+  FirstOrderReverseAD(LetList* ll) : ll(ll) { }
 
   ADValue VisitExpr_(const OpNode* op) final {
     Op op_ref = GetRef<Op>(op);
@@ -121,21 +121,22 @@ struct ReverseAD : ExprFunctor<ADValue(const Expr &)> {
     return std::make_shared<ADFunction>([this, op_ref](const std::vector<ADValue>& args,
                                                        const Attrs& attrs,
                                                        const tvm::Array<Type>& type_args) {
-      std::vector<Expr> call_args;
-      for (const ADValue& adval : args) {
-        call_args.push_back(adval->get<ADTensor>().foward);
+        std::vector<Expr> call_args;
+        for (const ADValue& adval : args) {
+          call_args.push_back(adval->get<ADTensor>().forward);
+        }
+        auto orig = CallNode::make(op_ref, call_args, attrs, type_args);
+        auto ret = std::make_shared<ADTensor>(ll, orig);
+        backprop_actions.push_back([this, args, orig, ret, op_ref](LetList* ll) {
+          tvm::Array<Expr> rev = rev_map[op_ref](orig, ret->reverse);
+          CHECK(args.size() == rev.size());
+          for (size_t i = 0; i < args.size(); ++i) {
+            args[i]->get<ADTensor>().reverse =
+              ll->Push(Add(args[i]->get<ADTensor>().reverse, rev[i]));
       }
-      auto orig = CallNode::make(op_ref, call_args, attrs, type_args);
-      auto ret = std::make_shared<ADTensor>(ll, orig);
-      backprop_actions.push_back([this, args, orig, ret, op_ref](LetList* ll) {
-        tvm::Array<Expr> rev = rev_map[op_ref](orig, ret->reverse);
-        for (size_t i = 0; i < args.size(); ++i) {
-          args[i]->get<ADTensor>().reverse =
-            ll->Push(Add(args[i]->get<ADTensor>().reverse, rev[i]));
-        }
-      });
-      return ret;
       });
+        return ret;
+      });
   }
 
   ADValue VisitExpr_(const ConstantNode* op) final {
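The FirstOrderReverseAD visitor above records, for every primitive call, a closure that later adds the operator's primal gradient into each argument's accumulator; the closures are replayed in reverse to perform backpropagation. A rough Python analogue of that tape mechanism (an illustration only, not TVM code; multiply stands in for any operator with a registered FPrimalGradient):

    backprop_actions = []               # the tape, replayed in reverse order

    class ADTensor:
        def __init__(self, forward):
            self.forward = forward
            self.reverse = 0.0          # gradient accumulator (ZerosLike in the pass)

    def mul(a, b):
        out = ADTensor(a.forward * b.forward)
        def action():
            # primal gradient of multiply, accumulated into the inputs
            a.reverse += out.reverse * b.forward
            b.reverse += out.reverse * a.forward
        backprop_actions.append(action)
        return out

    x, y = ADTensor(3.0), ADTensor(4.0)
    z = mul(x, y)
    z.reverse = 1.0                     # OnesLike on the result
    for action in reversed(backprop_actions):
        action()
    print(x.reverse, y.reverse)         # 4.0 3.0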
@@ -172,6 +173,23 @@ struct ReverseAD : ExprFunctor<ADValue(const Expr &)> {
   }
 };
 
+Type GradRetType(const Function& f) {
+  // if type annotations are provided, we will construct a ret type;
+  // otherwise, leave it to be inferred
+  if (!f->ret_type.defined()) {
+    return Type();
+  }
+  std::vector<Type> vt;
+  for (const auto& p : f->params) {
+    if (!p->type_annotation.defined()) {
+      return Type();
+    }
+    vt.push_back(p->type_annotation);
+  }
+
+  return TupleTypeNode::make({f->ret_type, TupleTypeNode::make(vt)});
+}
+
 Expr FirstOrderGradient(const Expr& re, const Module& mod) {
   // Currently we first remove any global functions for the first
   // order case.
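GradRetType above only constructs an explicit return type when the result and every parameter carry annotations; otherwise it returns an empty Type and inference fills it in. For a fully annotated fn (x: T, y: T) -> T, the transformed function returns a pair of the original result and a tuple of per-parameter gradients, i.e. (sketched with the relay Python API; the shapes are arbitrary):

    from tvm import relay

    t = relay.TensorType((3, 4), "float32")
    # Return type produced for the gradient of fn (x: t, y: t) -> t:
    grad_ret_type = relay.TupleType([t, relay.TupleType([t, t])])
    print(grad_ret_type)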
@@ -182,7 +200,7 @@ Expr FirstOrderGradient(const Expr& re, const Module& mod) {
 
   // We will then build a sequence of lets which implement reverse mode.
   Expr body = LetList::With([&](LetList* ll) {
-    ReverseAD reverse_ad(ll);
+    FirstOrderReverseAD reverse_ad(ll);
     ADValue rev = reverse_ad(e);
     std::vector<ADValue> args;
     for (const auto& p : f->params) {
@@ -191,46 +209,131 @@ Expr FirstOrderGradient(const Expr& re, const Module& mod) {
     auto c = rev->get<ADFunction>().func(args, Attrs(), {});
     const auto& res = c->get<ADTensor>();
     Expr grad = LetList::With([&](LetList* ll) {
-      res.reverse = OneLike(res.foward);
-      for (auto it = reverse_ad.backprop_actions.rbegin();
-           it != reverse_ad.backprop_actions.rend();
-           ++it) {
-        (*it)(ll);
+      res.reverse = OnesLike(res.forward);
+      for (auto it = reverse_ad.backprop_actions.rbegin();
+           it != reverse_ad.backprop_actions.rend();
+           ++it) {
+        (*it)(ll);
+      }
+      std::vector<Expr> grad_res;
+      for (const auto& a : args) {
+        grad_res.push_back(a->get<ADTensor>().reverse);
+      }
+      return TupleNode::make(grad_res);
+    });
+    return Pair(res.forward, grad);
+  });
+
+  return FunctionNode::make(f->params, body, GradRetType(GetRef<Function>(f)), {});
+}
+
+TVM_REGISTER_API("relay._ir_pass.first_order_gradient")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+    CHECK_EQ(args.size(), 2);
+    *ret = FirstOrderGradient(args[0], args[1]);
+  });
+
+struct ReverseADType : TypeMutator {
+  Type VisitType_(const TensorTypeNode* ttn) final {
+    Type t = GetRef<Type>(ttn);
+    return TupleTypeNode::make({t, RefTypeNode::make(t)});
+  }
+};
+
+struct ReverseAD : ExprMutator {
+  Var bp;
+  const OpMap<FPrimalGradient> rev_map = Op::GetAttr<FPrimalGradient>("FPrimalGradient");
+
+  ReverseAD(const Var& bp) : bp(bp) { }
+
+  Expr VisitExpr_(const OpNode* op) final {
+    LOG(FATAL) << "op should only be inside call";
+    throw;
+  }
+
+  Expr VisitExpr_(const CallNode* op) final {
+    if (const OpNode* op_node = op->op.as<OpNode>()) {
+      Op op_ref = GetRef<Op>(op_node);
+      CHECK(rev_map.count(op_ref))
+        << op_node->name << " does not have reverse mode defined";
+      return LetList::With([&](LetList* ll) {
+        std::vector<Var> args;
+        for (const auto& arg : op->args) {
+          args.push_back(ll->Push(VisitExpr(arg)));
       }
-      std::vector<Expr> grad_res;
-      for (const auto& a : args) {
-        grad_res.push_back(a->get<ADTensor>().reverse);
+        std::vector<Expr> orig_args;
+        for (const auto& arg : args) {
+          orig_args.push_back(GetField(VisitExpr(arg), 0));
       }
-      return TupleNode::make(grad_res);
+        Expr orig = CallNode::make(op->op, orig_args, op->attrs, op->type_args);
+        Var orig_var = ll->Push(orig);
+        auto ref = ll->Push(RefCreateNode::make(ZerosLike(orig_var)));
+        auto bpv = ll->Push(RefReadNode::make(bp));
+        Expr nbp = FunctionNode::make(
+          {},
+          LetList::With([&](LetList* ll) {
+            tvm::Array<Expr> rev = rev_map[op_ref](orig, ll->Push(RefReadNode::make(ref)));
+            CHECK(args.size() == rev.size());
+            for (size_t i = 0; i < args.size(); ++i) {
+              ll->Push(RefWriteNode::make(GetField(args[i], 1),
+                                          Add(ll->Push(RefReadNode::make(GetField(args[i], 1))),
+                                              rev[i])));
+            }
+            return CallNode::make(bpv, {});
+          }),
+          TupleTypeNode::make({}),
+          {});
+        ll->Push(RefWriteNode::make(bp, nbp));
+        return Pair(orig_var, ref);
     });
-    return Pair(res.foward, grad);
-  });
-
-  // if type annotations are provided, we will construct a ret type;
-  // otherwise, leave it to be inferred
-  Type ret_type = Type();
-  std::vector<Type> vt;
-  bool missing = !f->ret_type.defined();
-  for (const auto& p : f->params) {
-    if (missing || !p->type_annotation.defined()) {
-      missing = true;
-      break;
     }
-    vt.push_back(p->type_annotation);
+    return ExprMutator::VisitExpr_(op);
+  }
+
+  Expr VisitExpr_(const ConstantNode* op) final {
+    Expr e = GetRef<Expr>(op);
+    return Pair(e, RefCreateNode::make(ZerosLike(e)));
   }
 
-  if (!missing) {
-    ret_type = TupleTypeNode::make({f->ret_type, TupleTypeNode::make(vt)});
+  Type VisitType(const Type& t) final {
+    return t.defined() ? ReverseADType()(t) : t;
   }
+};
 
-  return FunctionNode::make(f->params, body, ret_type, {});
+Expr BPEmpty() {
+  Expr unitF = FunctionNode::make({}, TupleNode::make({}), TupleTypeNode::make({}), {});
+  return RefCreateNode::make(unitF);
 }
 
-TVM_REGISTER_API("relay._ir_pass.first_order_gradient")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-    CHECK_EQ(args.size(), 2);
-    *ret = FirstOrderGradient(args[0], args[1]);
-  });
+Expr Gradient(const Expr& re, const Module& mod) {
+  auto e = DeGlobal(mod, re);
+  auto f = e.as<FunctionNode>();
+  CHECK(f) << "input need to be a function";
+  CHECK(f->type_params.size() == 0) << "no polymorphism supported for now";
+  Expr body = LetList::With([&](LetList* ll) {
+    Var bp = ll->Push(BPEmpty());
+    Expr rev = ReverseAD(bp)(e);
+    std::vector<Expr> args;
+    for (const auto& p : f->params) {
+      args.push_back(ll->Push(Pair(p, RefCreateNode::make(ZerosLike(p)))));
+    }
+    auto c = ll->Push(CallNode::make(rev, args));
+    ll->Push(RefWriteNode::make(GetField(c, 1), OnesLike(GetField(c, 0))));
+    ll->Push(CallNode::make(RefReadNode::make(bp), {}));
+    std::vector<Expr> ret;
+    for (const auto& a : args) {
+      ret.push_back(RefReadNode::make(GetField(a, 1)));
+    }
+    return Pair(GetField(c, 0), TupleNode::make(ret));
+  });
+  return FunctionNode::make(f->params, body, GradRetType(GetRef<Function>(f)), {});
+}
+
+TVM_REGISTER_API("relay._ir_pass.gradient")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+    CHECK_EQ(args.size(), 2);
+    *ret = Gradient(args[0], args[1]);
+  });
 
 } // namespace relay
 } // namespace tvm
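In the higher order pass above, every tensor value becomes a pair of its forward value and a reference cell holding its adjoint, and a single backpropagator reference bp is threaded through the program; each call prepends its reverse step to bp, so invoking the final backpropagator unwinds the whole computation, including through closures and control flow. A rough Python analogue of that wiring (an illustration only, not TVM code):

    class Ref:                          # stands in for Relay's reference type
        def __init__(self, value): self.value = value

    bp = Ref(lambda: None)              # BPEmpty: the initial no-op backpropagator

    def lift(x):                        # parameters/constants: Pair(e, RefCreate(ZerosLike(e)))
        return (x, Ref(0.0))

    def mul(a, b):
        orig = a[0] * b[0]
        ref = Ref(0.0)                  # adjoint of this call's result
        prev_bp = bp.value              # RefRead(bp)
        def new_bp():                   # the closure written back into bp
            a[1].value += ref.value * b[0]
            b[1].value += ref.value * a[0]
            prev_bp()                   # keep unwinding the chain
        bp.value = new_bp               # RefWrite(bp, nbp)
        return (orig, ref)

    x, y = lift(3.0), lift(4.0)
    out = mul(x, y)
    out[1].value = 1.0                  # OnesLike on the output adjoint
    bp.value()                          # run the accumulated reverse steps
    print(x[1].value, y[1].value)       # 4.0 3.0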

src/relay/pass/pattern_util.h

Lines changed: 2 additions & 2 deletions
@@ -299,12 +299,12 @@ inline Expr Divide(Expr lhs, Expr rhs) {
   return CallNode::make(op, {lhs, rhs}, Attrs(), {});
 }
 
-inline Expr ZeroLike(Expr e) {
+inline Expr ZerosLike(Expr e) {
   static const Op& op = Op::Get("zeros_like");
   return CallNode::make(op, {e});
 }
 
-inline Expr OneLike(Expr e) {
+inline Expr OnesLike(Expr e) {
   static const Op& op = Op::Get("ones_like");
   return CallNode::make(op, {e});
 }

src/relay/pass/type_infer.cc

Lines changed: 1 addition & 1 deletion
@@ -53,7 +53,7 @@ bool TupleGetItemRel(const Array<Type>& types,
   const auto* param = attrs.as<TupleGetItemAttrs>();
   CHECK(param != nullptr);
   CHECK_GE(param->index, 0);
-    CHECK_LT(param->index, data->fields.size());
+  CHECK_LT(param->index, data->fields.size());
   reporter->Assign(types[1], data->fields[param->index]);
   return true;
 }

0 commit comments
