@@ -27,6 +27,8 @@ Graph InferAttr(Graph &&ret,
2727 Op::GetAttr<FInferNodeEntryAttr<AttrType> >(infer_name);
2828 static auto & backward_map =
2929 Op::GetAttr<FBackwardOutToInIndex>(" FBackwardOutToInIndex" );
30+ static auto & backward_in_grad =
31+ Op::GetAttr<FBackwardInGradIndex>(" FBackwardInGradIndex" );
3032 // reshape shape vector
3133 AttrVector rshape;
3234 if (ret.attrs .count (attr_name) != 0 ) {
@@ -54,7 +56,6 @@ Graph InferAttr(Graph &&ret,
5456 }
5557 // Temp space for shape inference.
5658 std::vector<AttrType> ishape, oshape;
57- size_t num_unknown;
5859
5960 // inference step function for nid
6061 auto infer_step = [&](uint32_t nid) {
@@ -76,21 +77,29 @@ Graph InferAttr(Graph &&ret,
7677 } else if (backward_map.count (inode.source ->op ())) {
7778 // Backward operator inference.
7879 CHECK_GE (inode.control_deps .size (), 1 )
79- << " BackwardOp need to have control_deps to its forward op" ;
80+ << " BackwardOp need to have control_deps to its forward op" ;
8081 const IndexedGraph::Node& fnode = idx[inode.control_deps [0 ]];
8182 // Inference the outputs of backward operator (equal to the inputs
8283 // of its corresponding forward operator).
8384 std::vector<uint32_t > out_map =
8485 backward_map[inode.source ->op ()](inode.source ->attrs );
85- bool known = true ;
8686 for (size_t i = 0 ; i < out_map.size (); ++i) {
8787 uint32_t in_id = out_map[i];
8888 CHECK_LT (in_id, fnode.inputs .size ());
8989 rshape[idx.entry_id (nid, i)] =
9090 rshape[idx.entry_id (fnode.inputs [in_id])];
91- if (fis_none (rshape[idx.entry_id (nid, i)])) known = false ;
9291 }
93- num_unknown += !known;
92+ if (backward_in_grad.count (inode.source ->op ())) {
93+ std::vector<uint32_t > in_grad =
94+ backward_in_grad[inode.source ->op ()](inode.source ->attrs );
95+ CHECK_LE (in_grad.size (), fnode.source ->num_outputs ());
96+ for (size_t i = 0 ; i < in_grad.size (); ++i) {
97+ uint32_t eid = idx.entry_id (inode.inputs [in_grad[i]]);
98+ if (fis_none (rshape[eid])) {
99+ rshape[eid] = rshape[idx.entry_id (inode.control_deps [0 ], i)];
100+ }
101+ }
102+ }
94103 } else {
95104 bool forward_known = true ;
96105 // Forward operator inference.
@@ -112,7 +121,6 @@ Graph InferAttr(Graph &&ret,
112121 // Call inference function of the operator.
113122 forward_known = finfer (inode.source ->attrs , &ishape, &oshape);
114123 }
115- num_unknown += !forward_known;
116124 // Save to the result map.
117125 for (uint32_t i = 0 ; i < num_inputs; ++i) {
118126 rshape[idx.entry_id (inode.inputs [i])] = ishape[i];
@@ -123,16 +131,24 @@ Graph InferAttr(Graph &&ret,
123131 }
124132 };
125133
126- num_unknown = 0 ;
127- for (uint32_t nid = 0 ; nid < idx.num_nodes (); ++nid) {
128- infer_step (nid);
129- }
130- if (num_unknown != 0 ) {
134+ size_t num_unknown = 0 ;
135+ const int kMaxStep = 3 ;
136+ for (int i = 0 ; i < kMaxStep ; ++i) {
137+ if (i % 2 == 0 ) {
138+ for (uint32_t nid = 0 ; nid < idx.num_nodes (); ++nid) {
139+ infer_step (nid);
140+ }
141+ } else {
142+ // backward inference
143+ for (uint32_t i = idx.num_nodes (); i != 0 ; --i) {
144+ infer_step (i - 1 );
145+ }
146+ }
131147 num_unknown = 0 ;
132- // backward inference
133- for (uint32_t i = idx.num_nodes (); i != 0 ; --i) {
134- infer_step (i - 1 );
148+ for (size_t i = 0 ; i < idx.num_node_entries (); ++i) {
149+ if (fis_none (rshape[i])) ++num_unknown;
135150 }
151+ if (num_unknown == 0 ) break ;
136152 }
137153 // set the shapes
138154 ret.attrs [attr_name] = std::make_shared<any>(std::move (rshape));
0 commit comments