Skip to content

Commit a185e85

Browse files
committed
[PASS] PrintGraphIR, SimplifyBatchNormInference (apache#19)
1 parent 147a13f commit a185e85

File tree

19 files changed

+650
-24
lines changed

19 files changed

+650
-24
lines changed

nnvm/include/nnvm/node.h

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,33 @@ class Node {
128128
static NodePtr Create();
129129
};
130130

131+
/*!
132+
* \brief Quick utilities make node.
133+
* \param op_name The name of operator
134+
* \param node_name The name of the node
135+
* \param inputs The input entries
136+
* \param attrs The attributes
137+
* \return The created node entry.
138+
*/
139+
inline NodeEntry MakeNode(
140+
const char* op_name,
141+
std::string node_name,
142+
std::vector<NodeEntry> inputs,
143+
std::unordered_map<std::string, std::string> attrs =
144+
std::unordered_map<std::string, std::string>()) {
145+
NodePtr p = Node::Create();
146+
p->attrs.op = nnvm::Op::Get(op_name);
147+
p->attrs.name = std::move(node_name);
148+
if (attrs.size() != 0) {
149+
p->attrs.dict = attrs;
150+
if (p->attrs.op->attr_parser) {
151+
p->attrs.op->attr_parser(&(p->attrs));
152+
}
153+
}
154+
p->inputs = std::move(inputs);
155+
return NodeEntry{p, 0, 0};
156+
}
157+
131158
// implementation of functions.
132159
inline const Op* Node::op() const {
133160
return this->attrs.op;

nnvm/python/nnvm/compiler/graph_attr.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,5 @@ def set_layout_inputs(g, layout):
8383
g._set_json_attr("layout_inputs", list_shape, 'list_str')
8484
return g
8585

86-
87-
_move_out_module = tvm.get_global_func("nnvm.graph_attr._move_module")
88-
_move_out_graph = tvm.get_global_func("nnvm.graph_attr._move_graph")
86+
_move_out_module = tvm.get_global_func("nnvm.graph._move_module")
87+
_move_out_graph = tvm.get_global_func("nnvm.graph._move_graph")

nnvm/python/nnvm/compiler/graph_pass.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
"""
88
from __future__ import absolute_import as _abs
99

10+
import tvm
1011
from . import graph_attr
1112

1213

@@ -60,3 +61,26 @@ def infer_dtype(graph, **dtype):
6061
output_dtype = [graph_attr.TCODE_TO_DTYPE[dtype[index.entry_id(x)]]
6162
for x in index.output_entries]
6263
return input_dtype, output_dtype
64+
65+
66+
_deep_compare = tvm.get_global_func("nnvm.graph.DeepCompare")
67+
68+
def check_graph_equal(grapha, graphb):
69+
"""Check if two graphs have equal structure.
70+
71+
Parameters
72+
----------
73+
grapha : Graph
74+
The first graph
75+
76+
graphb : Graph
77+
The second graph
78+
79+
Raises
80+
------
81+
ValueError
82+
ValueError is raised with error message when graph not equal
83+
"""
84+
err = _deep_compare(grapha, graphb)
85+
if err:
86+
raise ValueError("Graph compare error: " + err)

nnvm/python/nnvm/graph.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,10 @@ def index(self):
177177
self._index = GraphIndex(self)
178178
return self._index
179179

180+
def graphir(self):
181+
"""Get text form of graph ir."""
182+
return self.apply("PrintGraphIR").json_attr("graphir")
183+
180184
def apply(self, passes):
181185
"""Apply passes to the graph
182186

nnvm/python/nnvm/top/nn.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
# relu
1212
@reg.register_compute("relu")
13-
def compute_relu(attrs, inputs):
13+
def compute_relu(_, inputs):
1414
"""Compute definition of relu"""
1515
return topi.nn.relu(inputs[0])
1616

@@ -72,8 +72,7 @@ def schedule_conv2d(attrs, outs, target):
7272
if target == "cuda":
7373
if groups == 1:
7474
return topi.cuda.schedule_conv2d_nchw(outs)
75-
else:
76-
return topi.cuda.schedule_depthwise_conv2d_nchw(outs)
75+
return topi.cuda.schedule_depthwise_conv2d_nchw(outs)
7776
# naive schedule
7877
return tvm.create_schedule([x.op for x in outs])
7978

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
/*!
2+
* Copyright (c) 2017 by Contributors
3+
* \file graph_deep_compare.cc
4+
* \brief Deep compare two graph structure
5+
*/
6+
#include <nnvm/graph.h>
7+
#include <nnvm/op_attr_types.h>
8+
#include <nnvm/compiler/packed_func_ext.h>
9+
#include <tvm/runtime/packed_func.h>
10+
#include "./node_attr.h"
11+
12+
namespace nnvm {
13+
namespace compiler {
14+
15+
// deep compare the graph structure
16+
// not considering the graph attributes
17+
// return non-empty error message if the graph mismatch.
18+
// the comparator won't match name of intermediate node.
19+
std::string DeepCompare(Graph a, Graph b) {
20+
const IndexedGraph& idxa = a.indexed_graph();
21+
const IndexedGraph& idxb = b.indexed_graph();
22+
std::ostringstream err;
23+
if (idxa.num_nodes() != idxb.num_nodes()) {
24+
err << "Number of nodes mismatch";
25+
return err.str();
26+
}
27+
if (idxa.num_node_entries() != idxb.num_node_entries()) {
28+
err << "Number of node entry mismatch";
29+
return err.str();
30+
}
31+
if (idxa.outputs().size() != idxb.outputs().size()) {
32+
err << "Number of outputs mismatch";
33+
return err.str();
34+
}
35+
for (size_t i = 0; i < idxa.outputs().size(); ++i) {
36+
if (idxa.outputs()[i].node_id != idxb.outputs()[i].node_id ||
37+
idxa.outputs()[i].index != idxb.outputs()[i].index) {
38+
err << "Output entry mismatch";
39+
return err.str();
40+
}
41+
}
42+
if (idxa.input_nodes().size() != idxb.input_nodes().size()) {
43+
err << "Number of inputs mismatch";
44+
return err.str();
45+
}
46+
47+
for (uint32_t nid = 0; nid < idxa.num_nodes(); ++nid) {
48+
const IndexedGraph::Node& anode = idxa[nid];
49+
const IndexedGraph::Node& bnode = idxb[nid];
50+
if (anode.source->op() != bnode.source->op()) {
51+
err << "Node mismatch ";
52+
return err.str();
53+
}
54+
AttrDict adict = GetAttrDict(anode.source->attrs);
55+
AttrDict bdict = GetAttrDict(bnode.source->attrs);
56+
57+
auto fmatch = [&err, &anode](const AttrDict& adict, const AttrDict& bdict) {
58+
for (const auto& kv : adict) {
59+
auto it = bdict.find(kv.first);
60+
if (it != bdict.end()) {
61+
if (it->second != kv.second) {
62+
err << "Node attr mismatch, op=" << anode.source->attrs.name
63+
<< " attr_key=" << kv.first << " " << it->second
64+
<< " v.s. " << kv.second;
65+
return false;
66+
}
67+
} else {
68+
err << "One attr_key=" << kv.first << " is missing in another "
69+
<< "op=" << anode.source->attrs.name;
70+
return false;
71+
}
72+
}
73+
return true;
74+
};
75+
if (!fmatch(adict, bdict)) return err.str();
76+
if (adict.size() != bdict.size()) {
77+
CHECK(!fmatch(bdict, adict));
78+
return err.str();
79+
}
80+
if (anode.inputs.size() != bnode.inputs.size()) {
81+
err << "Node input mismatch, op=" << anode.source->attrs.name;
82+
return err.str();
83+
}
84+
if (anode.control_deps.size() != bnode.control_deps.size()) {
85+
err << "Node control_deps mistach, op=" << anode.source->attrs.name;
86+
return err.str();
87+
}
88+
for (size_t i = 0; i < anode.inputs.size(); ++i) {
89+
const IndexedGraph::NodeEntry& ae = anode.inputs[i];
90+
const IndexedGraph::NodeEntry& be = bnode.inputs[i];
91+
if (ae.node_id != be.node_id ||
92+
ae.index != be.index ||
93+
ae.version != be.version) {
94+
err << "Node input mismatch on, op=" << anode.source->attrs.name;
95+
return err.str();
96+
}
97+
}
98+
for (size_t i = 0; i < anode.control_deps.size(); ++i) {
99+
if (anode.control_deps[i] != bnode.control_deps[i]) {
100+
err << "Node control_dep mismatch on, op=" << anode.source->attrs.name;
101+
return err.str();
102+
}
103+
}
104+
}
105+
return "";
106+
}
107+
108+
TVM_REGISTER_GLOBAL("nnvm.graph.DeepCompare")
109+
.set_body([](tvm::runtime::TVMArgs args, tvm::runtime::TVMRetValue *rv) {
110+
*rv = DeepCompare(args[0], args[1]);
111+
});
112+
} // namespace compiler
113+
} // namespace nnvm

nnvm/src/compiler/pass/graph_fuse.cc renamed to nnvm/src/compiler/graph_fuse.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
#include <tvm/runtime/packed_func.h>
1414
#include <tvm/operation.h>
1515
#include <tvm/lowered_func.h>
16-
#include "../../runtime/graph_executor.h"
16+
#include "../runtime/graph_executor.h"
1717

1818
namespace nnvm {
1919
namespace compiler {
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
/*!
2+
* Copyright (c) 2017 by Contributors
3+
* \file graph_transform.h
4+
* \brief A mutator class that does local pattern matching and mutates a node.
5+
*/
6+
#ifndef NNVM_COMPILER_GRAPH_TRANSFORM_H_
7+
#define NNVM_COMPILER_GRAPH_TRANSFORM_H_
8+
9+
#include <nnvm/graph.h>
10+
#include <vector>
11+
12+
namespace nnvm {
13+
namespace compiler {
14+
15+
/*!
16+
* \brief Transform the graph to build a new Graph, in post DFS order.
17+
*
18+
* Automatically copies node when some of its children or control_deps changed.
19+
* This function won't be called in Variable.
20+
*
21+
* \param graph The original graph
22+
*
23+
* \param ftransform Function of (int nid, const Node* node, std::vector<NodeEntry>* out) -> bool
24+
*
25+
* If empty vector is returned, it means original entries should be kept.
26+
*
27+
* \tparam FTransform The transformation function.
28+
*/
29+
template<typename FTransform>
30+
Graph GraphTransform(Graph graph, FTransform ftransform) {
31+
const IndexedGraph& idx = graph.indexed_graph();
32+
// new nodes
33+
std::vector<NodeEntry> new_entry_map(idx.num_node_entries());
34+
std::vector<bool> updated(idx.num_node_entries(), false);
35+
36+
// setup inputs and placeholder.
37+
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
38+
const auto& inode = idx[nid];
39+
if (inode.source->is_variable()) continue;
40+
bool need_copy = false;
41+
for (const IndexedGraph::NodeEntry& e : inode.inputs) {
42+
if (updated[idx.entry_id(e)]) {
43+
need_copy = true; break;
44+
}
45+
}
46+
if (!need_copy) {
47+
for (const uint32_t cid : inode.control_deps) {
48+
const auto& cnode = idx[cid];
49+
for (uint32_t i = 0 ; i < cnode.source->num_outputs(); ++i) {
50+
if (updated[idx.entry_id(cid, i)]) {
51+
need_copy = true;
52+
}
53+
}
54+
if (need_copy) break;
55+
}
56+
}
57+
58+
if (!need_copy) {
59+
std::vector<NodeEntry> ret;
60+
if (ftransform(nid, inode.source, &ret)) {
61+
CHECK_EQ(ret.size(), static_cast<size_t>(inode.source->num_outputs()));
62+
for (uint32_t i = 0 ; i < inode.source->num_outputs(); ++i) {
63+
updated[idx.entry_id(nid, i)] = true;
64+
new_entry_map[idx.entry_id(nid, i)] = ret[i];
65+
}
66+
}
67+
} else {
68+
NodePtr node = Node::Create();
69+
node->attrs = inode.source->attrs;
70+
for (size_t i = 0; i < inode.inputs.size(); ++i) {
71+
const IndexedGraph::NodeEntry& e = inode.inputs[i];
72+
if (updated[idx.entry_id(e)]) {
73+
node->inputs.push_back(new_entry_map[idx.entry_id(e)]);
74+
} else {
75+
node->inputs.push_back(inode.source->inputs[i]);
76+
}
77+
}
78+
for (size_t i = 0; i < inode.control_deps.size(); ++i) {
79+
const uint32_t cid = inode.control_deps[i];
80+
const auto& cnode = idx[cid];
81+
CHECK_NE(cnode.source->num_outputs(), 0U);
82+
NodePtr selected_ptr;
83+
for (uint32_t j = 0 ; j < cnode.source->num_outputs(); ++j) {
84+
NodePtr cptr = updated[idx.entry_id(cid, j)] ?
85+
new_entry_map[idx.entry_id(cid, j)].node : inode.source->control_deps[i];
86+
if (selected_ptr == nullptr) {
87+
selected_ptr = std::move(cptr);
88+
} else {
89+
CHECK(selected_ptr.get() == cptr.get())
90+
<< "Control dependency node changed to more than one node";
91+
}
92+
}
93+
node->control_deps.push_back(selected_ptr);
94+
}
95+
std::vector<NodeEntry> ret;
96+
if (ftransform(nid, node.get(), &ret)) {
97+
CHECK_EQ(ret.size(), static_cast<size_t>(inode.source->num_outputs()));
98+
for (uint32_t i = 0 ; i < inode.source->num_outputs(); ++i) {
99+
updated[idx.entry_id(nid, i)] = true;
100+
new_entry_map[idx.entry_id(nid, i)] = ret[i];
101+
}
102+
} else {
103+
for (uint32_t i = 0 ; i < inode.source->num_outputs(); ++i) {
104+
updated[idx.entry_id(nid, i)] = true;
105+
new_entry_map[idx.entry_id(nid, i)] = NodeEntry{node, i, 0};
106+
}
107+
}
108+
}
109+
}
110+
Graph ret;
111+
for (size_t i = 0; i < idx.outputs().size(); ++i) {
112+
const IndexedGraph::NodeEntry& e = idx.outputs()[i];
113+
if (updated[idx.entry_id(e)]) {
114+
ret.outputs.push_back(new_entry_map[idx.entry_id(e)]);
115+
} else {
116+
ret.outputs.push_back(graph.outputs[i]);
117+
}
118+
}
119+
return ret;
120+
}
121+
122+
} // namespace compiler
123+
} // namespace nnvm
124+
125+
#endif // NNVM_COMPILER_GRAPH_TRANSFORM_H_
File renamed without changes.

nnvm/src/compiler/node_attr.h

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
/*!
2+
* Copyright (c) 2017 by Contributors
3+
* \file node_attr.h
4+
* \brief utility to access node attributes
5+
*/
6+
#ifndef NNVM_COMPILER_NODE_ATTR_H_
7+
#define NNVM_COMPILER_NODE_ATTR_H_
8+
9+
#include <nnvm/op.h>
10+
#include <nnvm/compiler/op_attr_types.h>
11+
#include <unordered_map>
12+
#include <string>
13+
14+
namespace nnvm {
15+
namespace compiler {
16+
17+
using AttrDict = std::unordered_map<std::string, std::string>;
18+
/*!
19+
* \brief Get canonicalized attr dict from node
20+
* \param attrs The node attrs
21+
* \return The attribute dict
22+
*/
23+
inline AttrDict GetAttrDict(const NodeAttrs& attrs) {
24+
static auto& fgetdict = nnvm::Op::GetAttr<FGetAttrDict>("FGetAttrDict");
25+
if (fgetdict.count(attrs.op)) {
26+
return fgetdict[attrs.op](attrs);
27+
} else {
28+
return attrs.dict;
29+
}
30+
}
31+
32+
} // namespace compiler
33+
} // namespace nnvm
34+
#endif // NNVM_COMPILER_NODE_ATTR_H_

0 commit comments

Comments
 (0)