Skip to content

Commit c82fbdb

Browse files
address comment
1 parent 085457a commit c82fbdb

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

src/relay/pass/gradient.cc

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -85,10 +85,10 @@ using ADValue = std::shared_ptr<ADValueNode>;
8585

8686
/*! \brief AD over a program which generates a tensor output. */
8787
struct ADTensor : ADValueNode {
88-
Expr foward;
88+
Expr forward;
8989
mutable Expr reverse; // must be a variable to avoid duplication
90-
ADTensor(LetList* ll, const Expr& foward) :
91-
foward(ll->Push(foward)), reverse(ll->Push(ZeroLike(this->foward))) { }
90+
ADTensor(LetList* ll, const Expr& forward) :
91+
forward(ll->Push(forward)), reverse(ll->Push(ZeroLike(this->forward))) { }
9292
};
9393

9494
/*! \brief A staged representation of the program, we reflect
@@ -123,7 +123,7 @@ struct FirstOrderReverseAD : ExprFunctor<ADValue(const Expr &)> {
123123
const tvm::Array<Type>& type_args) {
124124
std::vector<Expr> call_args;
125125
for (const ADValue& adval : args) {
126-
call_args.push_back(adval->get<ADTensor>().foward);
126+
call_args.push_back(adval->get<ADTensor>().forward);
127127
}
128128
auto orig = CallNode::make(op_ref, call_args, attrs, type_args);
129129
auto ret = std::make_shared<ADTensor>(ll, orig);
@@ -209,7 +209,7 @@ Expr FirstOrderGradient(const Expr& re, const Module& mod) {
209209
auto c = rev->get<ADFunction>().func(args, Attrs(), {});
210210
const auto& res = c->get<ADTensor>();
211211
Expr grad = LetList::With([&](LetList* ll) {
212-
res.reverse = OneLike(res.foward);
212+
res.reverse = OneLike(res.forward);
213213
for (auto it = reverse_ad.backprop_actions.rbegin();
214214
it != reverse_ad.backprop_actions.rend();
215215
++it) {
@@ -221,7 +221,7 @@ Expr FirstOrderGradient(const Expr& re, const Module& mod) {
221221
}
222222
return TupleNode::make(grad_res);
223223
});
224-
return Pair(res.foward, grad);
224+
return Pair(res.forward, grad);
225225
});
226226

227227
return FunctionNode::make(f->params, body, GradRetType(GetRef<Function>(f)), {});
@@ -247,7 +247,7 @@ struct ReverseAD : ExprMutator {
247247
ReverseAD(const Var& bp) : bp(bp) { }
248248

249249
Expr VisitExpr_(const OpNode* op) final {
250-
CHECK(false) << "op should only be inside call";
250+
LOG(FATAL) << "op should only be inside call";
251251
throw;
252252
}
253253

0 commit comments

Comments
 (0)