Skip to content

Commit 79ceb9f

Browse files
committed
[PASS] PrintGraphIR Join attributes when print ir (apache#20)
1 parent c829bd8 commit 79ceb9f

File tree

5 files changed

+112
-5
lines changed

5 files changed

+112
-5
lines changed

nnvm/python/nnvm/graph.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -177,8 +177,23 @@ def index(self):
177177
self._index = GraphIndex(self)
178178
return self._index
179179

180-
def graphir(self):
181-
"""Get text form of graph ir."""
180+
def ir(self, join_entry_attrs=None, join_node_attrs=None):
181+
"""Get text form of graph ir.
182+
183+
Parameters
184+
----------
185+
join_entry_attrs : list of str
186+
List of graph NodeEntry attribute to be
187+
printed along each operator.
188+
189+
join_node_attrs : list of str
190+
List of graph node attribute to be
191+
printed along each operator.
192+
"""
193+
if join_entry_attrs:
194+
self._set_json_attr("join_entry_attrs", join_entry_attrs, "list_str")
195+
if join_node_attrs:
196+
self._set_json_attr("join_node_attrs", join_node_attrs, "list_str")
182197
return self.apply("PrintGraphIR").json_attr("graphir")
183198

184199
def apply(self, passes):

nnvm/src/pass/infer_shape_type.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@ Graph InferAttr(Graph &&ret,
6767
shape_attr_key = ret.GetAttr<std::string>(attr_key_name);
6868
// erase the provided arguments
6969
ret.attrs.erase(attr_key_name);
70+
} else {
71+
shape_attr_key = attr_name;
7072
}
7173
// Temp space for shape inference.
7274
std::vector<AttrType> ishape, oshape;

nnvm/src/pass/print_graph_ir.cc

Lines changed: 81 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,80 @@
55
*/
66
#include <nnvm/graph.h>
77
#include <nnvm/pass.h>
8+
#include <nnvm/tuple.h>
89
#include <iostream>
910

1011
namespace nnvm {
1112
namespace pass {
1213

14+
using AttrPrinter = std::function<void(uint32_t index, std::ostream& os)>; // NOLINT(*)
15+
16+
template<typename T>
17+
AttrPrinter GetVectorPrinter_(const T& vec) {
18+
return [&vec](uint32_t index, std::ostream& os) { // NOLINT(*)
19+
os << vec[index];
20+
};
21+
}
22+
23+
AttrPrinter GetVectorPrinter(const Graph& graph,
24+
const std::string& key) {
25+
auto it = graph.attrs.find(key);
26+
CHECK(it != graph.attrs.end())
27+
<< "Cannot find " << key << " in graph attr";
28+
const any& value = *(it->second);
29+
if (value.type() == typeid(std::vector<TShape>)) {
30+
return GetVectorPrinter_(
31+
nnvm::get<std::vector<TShape> >(value));
32+
} else if (value.type() == typeid(std::vector<int>)) {
33+
return GetVectorPrinter_(
34+
nnvm::get<std::vector<int> >(value));
35+
} else if (value.type() == typeid(std::vector<std::string>)) {
36+
return GetVectorPrinter_(
37+
nnvm::get<std::vector<std::string> >(value));
38+
} else {
39+
LOG(FATAL) << "Cannot handle type " << value.type().name();
40+
return nullptr;
41+
}
42+
}
43+
44+
1345
// print the graph ir in readable format
14-
void PrintGraphIR_(Graph src, std::ostream& os) { // NOLINT(*)
46+
void PrintGraphIR_(Graph src,
47+
const std::vector<std::string>& join_entry_attrs,
48+
const std::vector<std::string>& join_node_attrs,
49+
std::ostream& os) { // NOLINT(*)
1550
const IndexedGraph& idx = src.indexed_graph();
51+
std::vector<std::function<void(uint32_t, std::ostream&)> > trigger; // NOLINT(*)
52+
53+
for (const std::string& key : join_entry_attrs) {
54+
AttrPrinter fp = GetVectorPrinter(src, key);
55+
auto fprint = [&idx, key, fp](
56+
uint32_t nid, std::ostream& os) { // NOLINT(*)
57+
const IndexedGraph::Node& inode = idx[nid];
58+
os << ", " << key << "=";
59+
if (inode.source->num_outputs() != 1) {
60+
os << '[';
61+
for (uint32_t i = 0; i < inode.source->num_outputs(); ++i) {
62+
if (i != 0) os << ", ";
63+
fp(idx.entry_id(nid, i), os);
64+
}
65+
os << ']';
66+
} else {
67+
fp(idx.entry_id(nid, 0), os);
68+
}
69+
};
70+
trigger.push_back(fprint);
71+
}
72+
for (const std::string& key : join_node_attrs) {
73+
AttrPrinter fp = GetVectorPrinter(src, key);
74+
auto fprint = [&idx, key, fp](
75+
uint32_t nid, std::ostream& os) { // NOLINT(*)
76+
os << key << "=";
77+
fp(idx.entry_id(nid, 0), os);
78+
};
79+
trigger.push_back(fprint);
80+
}
81+
1682
os << "Graph(";
1783
if (idx.input_nodes().size() < 4) {
1884
for (size_t i = 0; i < idx.input_nodes().size(); ++i) {
@@ -79,6 +145,10 @@ void PrintGraphIR_(Graph src, std::ostream& os) { // NOLINT(*)
79145
}
80146
os << "]";
81147
}
148+
// additional attribute trigger
149+
for (const auto& fp : trigger) {
150+
fp(nid, os);
151+
}
82152
os << "\n";
83153
}
84154
os << " ret ";
@@ -112,7 +182,16 @@ void PrintGraphIR_(Graph src, std::ostream& os) { // NOLINT(*)
112182
// save a graph to json
113183
Graph PrintGraphIR(Graph src) {
114184
std::ostringstream os;
115-
PrintGraphIR_(src, os);
185+
std::vector<std::string> join_entry_attrs, join_node_attrs;
186+
if (src.attrs.count("join_entry_attrs") != 0) {
187+
join_entry_attrs = src.MoveCopyAttr<std::vector<std::string> >(
188+
"join_entry_attrs");
189+
}
190+
if (src.attrs.count("join_node_attrs") != 0) {
191+
join_node_attrs = src.MoveCopyAttr<std::vector<std::string> >(
192+
"join_node_attrs");
193+
}
194+
PrintGraphIR_(src, join_entry_attrs, join_node_attrs, os);
116195
Graph ret;
117196
ret.attrs["graphir"] = std::make_shared<any>(os.str());
118197
return ret;

nnvm/tests/python/compiler/test_simplify_batchnorm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def check(dim, axis, nstep):
3838
graph_attr.set_shape_inputs(g, ishape)
3939
g1 = g.apply("InferShape").apply("SimplifyBatchNormInference")
4040
# Some prints for debug
41-
# print(g1.graphir())
41+
# print(g1.ir())
4242
# assert graph equals as expected
4343
graph_pass.check_graph_equal(g1, g2)
4444

nnvm/tests/python/unittest/test_graph.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,19 @@ def test_plan_memory():
9999
assert (storage_id[jnode_row_ptr[nindex["add2"]]] ==
100100
storage_id[jnode_row_ptr[nindex["reshapek"]]])
101101

102+
def test_print_graph_ir():
103+
x = sym.Variable("x", shape=(1, 1, 10, 20))
104+
y = sym.conv2d(x + 1, name="y", channels=10, kernel_size=(3,3))
105+
g = graph.create(y)
106+
g = g.apply("InferShape")
107+
ir1 = g.ir()
108+
ir2 = g.ir(join_entry_attrs=["shape"])
109+
assert("y_bias" in ir1)
110+
assert("shape=" in ir2)
111+
102112

103113
if __name__ == "__main__":
114+
test_print_graph_ir()
104115
test_json_pass_with_attr()
105116
test_graph_json_attr()
106117
test_json_pass()

0 commit comments

Comments
 (0)