@@ -38,6 +38,7 @@ struct GradEntry {
3838 NodeEntry sum{nullptr , 0 , 0 };
3939#endif
4040 std::vector<NodeEntry> grads;
41+ bool need_attr_hint{true };
4142};
4243
4344Graph Gradient (Graph src) {
@@ -85,9 +86,6 @@ Graph Gradient(Graph src) {
8586 CHECK_EQ (ys.size (), ys_out_grad.size ());
8687 for (size_t i = 0 ; i < ys.size (); ++i) {
8788 NodeEntry ograd = ys_out_grad[i];
88- if (attr_hint_fun != nullptr ) {
89- ograd = attr_hint_fun (ograd, ys[i]);
90- }
9189 output_grads[ys[i].node .get ()][ys[i].index ].grads = { ograd };
9290 }
9391
@@ -121,27 +119,29 @@ Graph Gradient(Graph src) {
121119 const NodePtr& ptr = *rit;
122120 if (ptr->is_variable ()) continue ;
123121 out_agg_grads.clear ();
124- for (GradEntry& e : output_grads.at (ptr.get ())) {
122+ auto & out_grad_vec = output_grads.at (ptr.get ());
123+ for (uint32_t i = 0 ; i < out_grad_vec.size (); ++i) {
124+ GradEntry& e = out_grad_vec[i];
125125 e.sum = agg_fun (std::move (e.grads ));
126+ if (e.need_attr_hint && attr_hint_fun != nullptr ) {
127+ e.sum = attr_hint_fun (e.sum , NodeEntry{ptr, 0 , i});
128+ }
126129 out_agg_grads.push_back (e.sum );
127130 }
128131 if ((*rit)->inputs .size () != 0 ) {
129132 NodePtr fwd_node = (mirror_map.size () == 0 ? ptr : mirror_map.at (ptr.get ()));
130133 std::vector<NodeEntry> input_grads = grad_fun_map[ptr->op ()](
131134 fwd_node, out_agg_grads);
132-
133- if (attr_hint_fun != nullptr ) {
134- // only insert hint when shape inference function is not available.
135- for (size_t i = 0 ; i < input_grads.size (); ++i) {
136- if (finfer_shape.count (input_grads[i].node ->op ())) continue ;
137- input_grads[i] = attr_hint_fun (input_grads[i], fwd_node->inputs [i]);
138- }
139- }
140135 CHECK_EQ ((*rit)->inputs .size (), input_grads.size ())
141136 << " Gradient function not returning enough gradient" ;
142137 auto git = input_grads.begin ();
143138 for (auto it = (*rit)->inputs .begin (); it != (*rit)->inputs .end (); ++it, ++git) {
144- output_grads[it->node .get ()][it->index ].grads .emplace_back (std::move (*git));
139+ auto & ge = output_grads[it->node .get ()][it->index ];
140+ // if any of the backward ops can do shape inference, the hint is not necessary.
141+ if (finfer_shape.count (git->node ->op ())) {
142+ ge.need_attr_hint = false ;
143+ }
144+ ge.grads .emplace_back (std::move (*git));
145145 }
146146 }
147147 }
@@ -153,6 +153,9 @@ Graph Gradient(Graph src) {
153153 // aggregate the sum if it has not been computed yet
154154 if (entry.sum .node .get () == nullptr ) {
155155 entry.sum = agg_fun (std::move (entry.grads ));
156+ if (entry.need_attr_hint && attr_hint_fun != nullptr ) {
157+ entry.sum = attr_hint_fun (entry.sum , e);
158+ }
156159 }
157160 ret.outputs .emplace_back (std::move (entry.sum ));
158161 }
0 commit comments