Skip to content

Commit 45da871

Browse files
committed
[Infer] More robust inference, support backward inference (apache#54)
1 parent 647267d commit 45da871

File tree

1 file changed

+82
-32
lines changed

1 file changed

+82
-32
lines changed

nnvm/src/pass/infer_shape_type.cc

Lines changed: 82 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,16 @@ namespace nnvm {
1111
namespace pass {
1212
namespace {
1313

14-
template<typename AttrType, typename IsNone>
14+
template<typename AttrType, typename IsNone, typename FDefault>
1515
Graph InferAttr(Graph &&ret,
16-
const AttrType default_val,
16+
const AttrType empty_val,
1717
const char* infer_name,
1818
const char* input_name,
1919
const char* attr_key_name,
2020
const char* attr_name,
2121
const char* unknown_name,
22-
IsNone fis_none) {
22+
IsNone fis_none,
23+
FDefault fdefault) {
2324
using AttrVector = std::vector<AttrType>;
2425
const IndexedGraph& idx = ret.indexed_graph();
2526
static auto& finfer_shape =
@@ -31,7 +32,7 @@ Graph InferAttr(Graph &&ret,
3132
if (ret.attrs.count(attr_name) != 0) {
3233
rshape = ret.MoveCopyAttr<AttrVector>(attr_name);
3334
} else {
34-
rshape.resize(idx.num_node_entries(), default_val);
35+
rshape.resize(idx.num_node_entries(), empty_val);
3536
}
3637

3738
if (ret.attrs.count(input_name) != 0) {
@@ -51,12 +52,12 @@ Graph InferAttr(Graph &&ret,
5152
// erase the provided arguments
5253
ret.attrs.erase(attr_key_name);
5354
}
54-
5555
// Temp space for shape inference.
5656
std::vector<AttrType> ishape, oshape;
57-
// number of completed nodes
58-
size_t num_unknown = 0;
59-
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
57+
size_t num_unknown;
58+
59+
// inference step function for nid
60+
auto infer_step = [&](uint32_t nid) {
6061
const auto& inode = idx[nid];
6162
const uint32_t num_inputs = inode.inputs.size();
6263
const uint32_t num_outputs = inode.source->num_outputs();
@@ -72,27 +73,6 @@ Graph InferAttr(Graph &&ret,
7273
CHECK(is >> rshape[out_ent_id]) << "Invalid attribute";
7374
}
7475
}
75-
} else if (finfer_shape.count(inode.source->op())) {
76-
// Forward operator inference.
77-
ishape.resize(num_inputs, default_val);
78-
for (uint32_t i = 0; i < ishape.size(); ++i) {
79-
ishape[i] = rshape[idx.entry_id(inode.inputs[i])];
80-
}
81-
oshape.resize(num_outputs, default_val);
82-
for (uint32_t i = 0; i < oshape.size(); ++i) {
83-
oshape[i] = rshape[idx.entry_id(nid, i)];
84-
}
85-
// Call inference function of the operator.
86-
bool forward_known = finfer_shape[inode.source->op()](
87-
inode.source->attrs, &ishape, &oshape);
88-
num_unknown += !forward_known;
89-
// Save to the result map.
90-
for (uint32_t i = 0; i < num_inputs; ++i) {
91-
rshape[idx.entry_id(inode.inputs[i])] = ishape[i];
92-
}
93-
for (uint32_t i = 0; i < num_outputs; ++i) {
94-
rshape[idx.entry_id(nid, i)] = oshape[i];
95-
}
9676
} else if (backward_map.count(inode.source->op())) {
9777
// Backward operator inference.
9878
CHECK_GE(inode.control_deps.size(), 1)
@@ -111,6 +91,47 @@ Graph InferAttr(Graph &&ret,
11191
if (fis_none(rshape[idx.entry_id(nid, i)])) known = false;
11292
}
11393
num_unknown += !known;
94+
} else {
95+
bool forward_known = true;
96+
// Forward operator inference.
97+
ishape.resize(num_inputs, empty_val);
98+
for (uint32_t i = 0; i < ishape.size(); ++i) {
99+
ishape[i] = rshape[idx.entry_id(inode.inputs[i])];
100+
if (fis_none(ishape[i])) forward_known = false;
101+
}
102+
oshape.resize(num_outputs, empty_val);
103+
for (uint32_t i = 0; i < oshape.size(); ++i) {
104+
oshape[i] = rshape[idx.entry_id(nid, i)];
105+
if (fis_none(oshape[i])) forward_known = false;
106+
}
107+
if (!forward_known) {
108+
auto finfer = finfer_shape.get(inode.source->op(), fdefault);
109+
CHECK(finfer != nullptr)
110+
<< "Attribute " << infer_name
111+
<< " is not registed by op " << inode.source->op()->name;
112+
// Call inference function of the operator.
113+
forward_known = finfer(inode.source->attrs, &ishape, &oshape);
114+
}
115+
num_unknown += !forward_known;
116+
// Save to the result map.
117+
for (uint32_t i = 0; i < num_inputs; ++i) {
118+
rshape[idx.entry_id(inode.inputs[i])] = ishape[i];
119+
}
120+
for (uint32_t i = 0; i < num_outputs; ++i) {
121+
rshape[idx.entry_id(nid, i)] = oshape[i];
122+
}
123+
}
124+
};
125+
126+
num_unknown = 0;
127+
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
128+
infer_step(nid);
129+
}
130+
if (num_unknown != 0) {
131+
num_unknown = 0;
132+
// backward inference
133+
for (uint32_t i = idx.num_nodes(); i != 0; --i) {
134+
infer_step(i - 1);
114135
}
115136
}
116137
// set the shapes
@@ -127,19 +148,48 @@ NNVM_REGISTER_PASS(InferShape)
127148
std::move(ret), TShape(),
128149
"FInferShape", "shape_inputs", "shape_attr_key",
129150
"shape", "shape_num_unknown_nodes",
130-
[](const TShape& s) { return s.ndim() == 0; });
151+
[](const TShape& s) { return s.ndim() == 0; },
152+
nullptr);
131153
})
132154
.set_change_graph(false)
133155
.provide_graph_attr("shape");
134156

157+
// inference fucntion for same type
158+
inline bool SameType(const NodeAttrs& attrs,
159+
std::vector<int> *iattr,
160+
std::vector<int> *oattr) {
161+
int def_v = -1;
162+
for (int v : *oattr) {
163+
if (v != -1) {
164+
def_v = v; break;
165+
}
166+
}
167+
if (def_v == -1) {
168+
for (int v : *iattr) {
169+
if (v != -1) {
170+
def_v = v; break;
171+
}
172+
}
173+
}
174+
if (def_v == -1) return false;
175+
for (int& v : *oattr) {
176+
v = def_v;
177+
}
178+
for (int& v : *iattr) {
179+
v = def_v;
180+
}
181+
return true;
182+
}
183+
135184
NNVM_REGISTER_PASS(InferType)
136185
.describe("Infer the dtype of each node entries.")
137186
.set_body([](Graph ret) {
138187
return InferAttr<int>(
139-
std::move(ret), 0,
188+
std::move(ret), -1,
140189
"FInferType", "dtype_inputs", "dtype_attr_key",
141190
"dtype", "dtype_num_unknown_nodes",
142-
[](const int t) { return t == -1; });
191+
[](const int t) { return t == -1; },
192+
SameType);
143193
})
144194
.set_change_graph(false)
145195
.provide_graph_attr("dtype");

0 commit comments

Comments
 (0)