@@ -11,15 +11,16 @@ namespace nnvm {
1111namespace pass {
1212namespace {
1313
14- template <typename AttrType, typename IsNone>
14+ template <typename AttrType, typename IsNone, typename FDefault >
1515Graph InferAttr (Graph &&ret,
16- const AttrType default_val ,
16+ const AttrType empty_val ,
1717 const char * infer_name,
1818 const char * input_name,
1919 const char * attr_key_name,
2020 const char * attr_name,
2121 const char * unknown_name,
22- IsNone fis_none) {
22+ IsNone fis_none,
23+ FDefault fdefault) {
2324 using AttrVector = std::vector<AttrType>;
2425 const IndexedGraph& idx = ret.indexed_graph ();
2526 static auto & finfer_shape =
@@ -31,7 +32,7 @@ Graph InferAttr(Graph &&ret,
3132 if (ret.attrs .count (attr_name) != 0 ) {
3233 rshape = ret.MoveCopyAttr <AttrVector>(attr_name);
3334 } else {
34- rshape.resize (idx.num_node_entries (), default_val );
35+ rshape.resize (idx.num_node_entries (), empty_val );
3536 }
3637
3738 if (ret.attrs .count (input_name) != 0 ) {
@@ -51,12 +52,12 @@ Graph InferAttr(Graph &&ret,
5152 // erase the provided arguments
5253 ret.attrs .erase (attr_key_name);
5354 }
54-
5555 // Temp space for shape inference.
5656 std::vector<AttrType> ishape, oshape;
57- // number of completed nodes
58- size_t num_unknown = 0 ;
59- for (uint32_t nid = 0 ; nid < idx.num_nodes (); ++nid) {
57+ size_t num_unknown;
58+
59+ // inference step function for nid
60+ auto infer_step = [&](uint32_t nid) {
6061 const auto & inode = idx[nid];
6162 const uint32_t num_inputs = inode.inputs .size ();
6263 const uint32_t num_outputs = inode.source ->num_outputs ();
@@ -72,27 +73,6 @@ Graph InferAttr(Graph &&ret,
7273 CHECK (is >> rshape[out_ent_id]) << " Invalid attribute" ;
7374 }
7475 }
75- } else if (finfer_shape.count (inode.source ->op ())) {
76- // Forward operator inference.
77- ishape.resize (num_inputs, default_val);
78- for (uint32_t i = 0 ; i < ishape.size (); ++i) {
79- ishape[i] = rshape[idx.entry_id (inode.inputs [i])];
80- }
81- oshape.resize (num_outputs, default_val);
82- for (uint32_t i = 0 ; i < oshape.size (); ++i) {
83- oshape[i] = rshape[idx.entry_id (nid, i)];
84- }
85- // Call inference function of the operator.
86- bool forward_known = finfer_shape[inode.source ->op ()](
87- inode.source ->attrs , &ishape, &oshape);
88- num_unknown += !forward_known;
89- // Save to the result map.
90- for (uint32_t i = 0 ; i < num_inputs; ++i) {
91- rshape[idx.entry_id (inode.inputs [i])] = ishape[i];
92- }
93- for (uint32_t i = 0 ; i < num_outputs; ++i) {
94- rshape[idx.entry_id (nid, i)] = oshape[i];
95- }
9676 } else if (backward_map.count (inode.source ->op ())) {
9777 // Backward operator inference.
9878 CHECK_GE (inode.control_deps .size (), 1 )
@@ -111,6 +91,47 @@ Graph InferAttr(Graph &&ret,
11191 if (fis_none (rshape[idx.entry_id (nid, i)])) known = false ;
11292 }
11393 num_unknown += !known;
94+ } else {
95+ bool forward_known = true ;
96+ // Forward operator inference.
97+ ishape.resize (num_inputs, empty_val);
98+ for (uint32_t i = 0 ; i < ishape.size (); ++i) {
99+ ishape[i] = rshape[idx.entry_id (inode.inputs [i])];
100+ if (fis_none (ishape[i])) forward_known = false ;
101+ }
102+ oshape.resize (num_outputs, empty_val);
103+ for (uint32_t i = 0 ; i < oshape.size (); ++i) {
104+ oshape[i] = rshape[idx.entry_id (nid, i)];
105+ if (fis_none (oshape[i])) forward_known = false ;
106+ }
107+ if (!forward_known) {
108+ auto finfer = finfer_shape.get (inode.source ->op (), fdefault);
109+ CHECK (finfer != nullptr )
110+ << " Attribute " << infer_name
111+ << " is not registed by op " << inode.source ->op ()->name ;
112+ // Call inference function of the operator.
113+ forward_known = finfer (inode.source ->attrs , &ishape, &oshape);
114+ }
115+ num_unknown += !forward_known;
116+ // Save to the result map.
117+ for (uint32_t i = 0 ; i < num_inputs; ++i) {
118+ rshape[idx.entry_id (inode.inputs [i])] = ishape[i];
119+ }
120+ for (uint32_t i = 0 ; i < num_outputs; ++i) {
121+ rshape[idx.entry_id (nid, i)] = oshape[i];
122+ }
123+ }
124+ };
125+
126+ num_unknown = 0 ;
127+ for (uint32_t nid = 0 ; nid < idx.num_nodes (); ++nid) {
128+ infer_step (nid);
129+ }
130+ if (num_unknown != 0 ) {
131+ num_unknown = 0 ;
132+ // backward inference
133+ for (uint32_t i = idx.num_nodes (); i != 0 ; --i) {
134+ infer_step (i - 1 );
114135 }
115136 }
116137 // set the shapes
@@ -127,19 +148,48 @@ NNVM_REGISTER_PASS(InferShape)
127148 std::move (ret), TShape (),
128149 " FInferShape" , " shape_inputs" , " shape_attr_key" ,
129150 " shape" , " shape_num_unknown_nodes" ,
130- [](const TShape& s) { return s.ndim () == 0 ; });
151+ [](const TShape& s) { return s.ndim () == 0 ; },
152+ nullptr );
131153 })
132154.set_change_graph(false )
133155.provide_graph_attr(" shape" );
134156
157+ // inference fucntion for same type
158+ inline bool SameType (const NodeAttrs& attrs,
159+ std::vector<int > *iattr,
160+ std::vector<int > *oattr) {
161+ int def_v = -1 ;
162+ for (int v : *oattr) {
163+ if (v != -1 ) {
164+ def_v = v; break ;
165+ }
166+ }
167+ if (def_v == -1 ) {
168+ for (int v : *iattr) {
169+ if (v != -1 ) {
170+ def_v = v; break ;
171+ }
172+ }
173+ }
174+ if (def_v == -1 ) return false ;
175+ for (int & v : *oattr) {
176+ v = def_v;
177+ }
178+ for (int & v : *iattr) {
179+ v = def_v;
180+ }
181+ return true ;
182+ }
183+
135184NNVM_REGISTER_PASS (InferType)
136185.describe(" Infer the dtype of each node entries." )
137186.set_body([](Graph ret) {
138187 return InferAttr<int >(
139- std::move (ret), 0 ,
188+ std::move (ret), - 1 ,
140189 " FInferType" , " dtype_inputs" , " dtype_attr_key" ,
141190 " dtype" , " dtype_num_unknown_nodes" ,
142- [](const int t) { return t == -1 ; });
191+ [](const int t) { return t == -1 ; },
192+ SameType);
143193 })
144194.set_change_graph(false )
145195.provide_graph_attr(" dtype" );
0 commit comments