Skip to content

Commit 706e8c0

Browse files
committed
do hint insertion after aggregation (#81)
1 parent 272dd98 commit 706e8c0

File tree

3 files changed

+19
-18
lines changed

3 files changed

+19
-18
lines changed

nnvm/Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ endif
1010
include $(config)
1111

1212
export LDFLAGS = -pthread -lm
13-
export CFLAGS = -std=c++11 -Wall -O2 -msse2 -Wno-unknown-pragmas -funroll-loops\
13+
export CFLAGS = -std=c++11 -Wall -O2 -msse2 -Wno-unknown-pragmas -funroll-loops\
1414
-Iinclude -fPIC
1515

1616
ifneq ($(ADD_CFLAGS), NONE)

nnvm/make/config.mk

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,13 @@
2020
# choice of compiler
2121
#--------------------
2222

23-
export CC = gcc
24-
export CXX = g++
2523
export NVCC = nvcc
2624

2725
# the additional link flags you want to add
28-
ADD_LDFLAGS =
26+
ADD_LDFLAGS=
2927

3028
# the additional compile flags you want to add
31-
ADD_CFLAGS =
29+
ADD_CFLAGS=
3230

3331
#----------------------------
3432
# plugins

nnvm/src/pass/gradient.cc

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ struct GradEntry {
3838
NodeEntry sum{nullptr, 0, 0};
3939
#endif
4040
std::vector<NodeEntry> grads;
41+
bool need_attr_hint{true};
4142
};
4243

4344
Graph Gradient(Graph src) {
@@ -85,9 +86,6 @@ Graph Gradient(Graph src) {
8586
CHECK_EQ(ys.size(), ys_out_grad.size());
8687
for (size_t i = 0; i < ys.size(); ++i) {
8788
NodeEntry ograd = ys_out_grad[i];
88-
if (attr_hint_fun != nullptr) {
89-
ograd = attr_hint_fun(ograd, ys[i]);
90-
}
9189
output_grads[ys[i].node.get()][ys[i].index].grads = { ograd };
9290
}
9391

@@ -121,27 +119,29 @@ Graph Gradient(Graph src) {
121119
const NodePtr& ptr = *rit;
122120
if (ptr->is_variable()) continue;
123121
out_agg_grads.clear();
124-
for (GradEntry& e : output_grads.at(ptr.get())) {
122+
auto& out_grad_vec = output_grads.at(ptr.get());
123+
for (uint32_t i = 0; i < out_grad_vec.size(); ++i) {
124+
GradEntry& e = out_grad_vec[i];
125125
e.sum = agg_fun(std::move(e.grads));
126+
if (e.need_attr_hint && attr_hint_fun != nullptr) {
127+
e.sum = attr_hint_fun(e.sum, NodeEntry{ptr, 0, i});
128+
}
126129
out_agg_grads.push_back(e.sum);
127130
}
128131
if ((*rit)->inputs.size() != 0) {
129132
NodePtr fwd_node = (mirror_map.size() == 0 ? ptr : mirror_map.at(ptr.get()));
130133
std::vector<NodeEntry> input_grads = grad_fun_map[ptr->op()](
131134
fwd_node, out_agg_grads);
132-
133-
if (attr_hint_fun != nullptr) {
134-
// only insert hint when shape inference function is not available.
135-
for (size_t i = 0; i < input_grads.size(); ++i) {
136-
if (finfer_shape.count(input_grads[i].node->op())) continue;
137-
input_grads[i] = attr_hint_fun(input_grads[i], fwd_node->inputs[i]);
138-
}
139-
}
140135
CHECK_EQ((*rit)->inputs.size(), input_grads.size())
141136
<< "Gradient function not returning enough gradient";
142137
auto git = input_grads.begin();
143138
for (auto it = (*rit)->inputs.begin(); it != (*rit)->inputs.end(); ++it, ++git) {
144-
output_grads[it->node.get()][it->index].grads.emplace_back(std::move(*git));
139+
auto& ge = output_grads[it->node.get()][it->index];
140+
// if any of the backward ops can do shape inference, the hint is not necessary.
141+
if (finfer_shape.count(git->node->op())) {
142+
ge.need_attr_hint = false;
143+
}
144+
ge.grads.emplace_back(std::move(*git));
145145
}
146146
}
147147
}
@@ -153,6 +153,9 @@ Graph Gradient(Graph src) {
153153
// aggregate the sum if it has not been done yet
154154
if (entry.sum.node.get() == nullptr) {
155155
entry.sum = agg_fun(std::move(entry.grads));
156+
if (entry.need_attr_hint && attr_hint_fun != nullptr) {
157+
entry.sum = attr_hint_fun(entry.sum, e);
158+
}
156159
}
157160
ret.outputs.emplace_back(std::move(entry.sum));
158161
}

0 commit comments

Comments
 (0)