Skip to content

Commit e8fee6d

Browse files
committed
Add shape backward inference (apache#58)
1 parent 869a953 commit e8fee6d

File tree

2 files changed

+45
-14
lines changed

2 files changed

+45
-14
lines changed

nnvm/include/nnvm/op_attr_types.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,21 @@ using FInferType = FInferNodeEntryAttr<int>;
107107
using FBackwardOutToInIndex = std::function<
108108
std::vector<uint32_t> (const NodeAttrs& attrs)>;
109109

110+
/*!
111+
* \brief Whether this op is an explicit backward operator.
112+
* Returns the list of input indices that correspond to the outputs of the forward operator.
113+
*
114+
* If FBackwardInGradIndex exists:
115+
* - The first entry of the node's control_deps points to the corresponding forward operator.
116+
* - The FBackwardInGradIndex[i]-th input of the backward op corresponds to the i-th
117+
* output of the forward operator.
118+
*
119+
* \note Register under "FBackwardInGradIndex"
120+
* This enables easier shape/type inference for backward operators.
121+
*/
122+
using FBackwardInGradIndex = std::function<
123+
std::vector<uint32_t> (const NodeAttrs& attrs)>;
124+
110125
/*!
111126
* \brief Get possible inplace options.
112127
* This function enables optimization to reuse memory of inputs in output.

nnvm/src/pass/infer_shape_type.cc

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ Graph InferAttr(Graph &&ret,
2727
Op::GetAttr<FInferNodeEntryAttr<AttrType> >(infer_name);
2828
static auto& backward_map =
2929
Op::GetAttr<FBackwardOutToInIndex>("FBackwardOutToInIndex");
30+
static auto& backward_in_grad =
31+
Op::GetAttr<FBackwardInGradIndex>("FBackwardInGradIndex");
3032
// reshape shape vector
3133
AttrVector rshape;
3234
if (ret.attrs.count(attr_name) != 0) {
@@ -54,7 +56,6 @@ Graph InferAttr(Graph &&ret,
5456
}
5557
// Temp space for shape inference.
5658
std::vector<AttrType> ishape, oshape;
57-
size_t num_unknown;
5859

5960
// inference step function for nid
6061
auto infer_step = [&](uint32_t nid) {
@@ -76,21 +77,29 @@ Graph InferAttr(Graph &&ret,
7677
} else if (backward_map.count(inode.source->op())) {
7778
// Backward operator inference.
7879
CHECK_GE(inode.control_deps.size(), 1)
79-
<< "BackwardOp need to have control_deps to its forward op";
80+
<< "BackwardOp need to have control_deps to its forward op";
8081
const IndexedGraph::Node& fnode = idx[inode.control_deps[0]];
8182
// Infer the outputs of the backward operator (equal to the inputs
8283
// of its corresponding forward operator).
8384
std::vector<uint32_t> out_map =
8485
backward_map[inode.source->op()](inode.source->attrs);
85-
bool known = true;
8686
for (size_t i = 0; i < out_map.size(); ++i) {
8787
uint32_t in_id = out_map[i];
8888
CHECK_LT(in_id, fnode.inputs.size());
8989
rshape[idx.entry_id(nid, i)] =
9090
rshape[idx.entry_id(fnode.inputs[in_id])];
91-
if (fis_none(rshape[idx.entry_id(nid, i)])) known = false;
9291
}
93-
num_unknown += !known;
92+
if (backward_in_grad.count(inode.source->op())) {
93+
std::vector<uint32_t> in_grad =
94+
backward_in_grad[inode.source->op()](inode.source->attrs);
95+
CHECK_LE(in_grad.size(), fnode.source->num_outputs());
96+
for (size_t i = 0; i < in_grad.size(); ++i) {
97+
uint32_t eid = idx.entry_id(inode.inputs[in_grad[i]]);
98+
if (fis_none(rshape[eid])) {
99+
rshape[eid] = rshape[idx.entry_id(inode.control_deps[0], i)];
100+
}
101+
}
102+
}
94103
} else {
95104
bool forward_known = true;
96105
// Forward operator inference.
@@ -112,7 +121,6 @@ Graph InferAttr(Graph &&ret,
112121
// Call inference function of the operator.
113122
forward_known = finfer(inode.source->attrs, &ishape, &oshape);
114123
}
115-
num_unknown += !forward_known;
116124
// Save to the result map.
117125
for (uint32_t i = 0; i < num_inputs; ++i) {
118126
rshape[idx.entry_id(inode.inputs[i])] = ishape[i];
@@ -123,16 +131,24 @@ Graph InferAttr(Graph &&ret,
123131
}
124132
};
125133

126-
num_unknown = 0;
127-
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
128-
infer_step(nid);
129-
}
130-
if (num_unknown != 0) {
134+
size_t num_unknown = 0;
135+
const int kMaxStep = 3;
136+
for (int i = 0; i < kMaxStep; ++i) {
137+
if (i % 2 == 0) {
138+
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
139+
infer_step(nid);
140+
}
141+
} else {
142+
// backward inference
143+
for (uint32_t i = idx.num_nodes(); i != 0; --i) {
144+
infer_step(i - 1);
145+
}
146+
}
131147
num_unknown = 0;
132-
// backward inference
133-
for (uint32_t i = idx.num_nodes(); i != 0; --i) {
134-
infer_step(i - 1);
148+
for (size_t i = 0; i < idx.num_node_entries(); ++i) {
149+
if (fis_none(rshape[i])) ++num_unknown;
135150
}
151+
if (num_unknown == 0) break;
136152
}
137153
// set the shapes
138154
ret.attrs[attr_name] = std::make_shared<any>(std::move(rshape));

0 commit comments

Comments
 (0)