|
5 | 5 | */ |
6 | 6 | #include <nnvm/graph.h> |
7 | 7 | #include <nnvm/pass.h> |
| 8 | +#include <nnvm/tuple.h> |
8 | 9 | #include <iostream> |
9 | 10 |
|
10 | 11 | namespace nnvm { |
11 | 12 | namespace pass { |
12 | 13 |
|
| 14 | +using AttrPrinter = std::function<void(uint32_t index, std::ostream& os)>; // NOLINT(*) |
| 15 | + |
| 16 | +template<typename T> |
| 17 | +AttrPrinter GetVectorPrinter_(const T& vec) { |
| 18 | + return [&vec](uint32_t index, std::ostream& os) { // NOLINT(*) |
| 19 | + os << vec[index]; |
| 20 | + }; |
| 21 | +} |
| 22 | + |
| 23 | +AttrPrinter GetVectorPrinter(const Graph& graph, |
| 24 | + const std::string& key) { |
| 25 | + auto it = graph.attrs.find(key); |
| 26 | + CHECK(it != graph.attrs.end()) |
| 27 | + << "Cannot find " << key << " in graph attr"; |
| 28 | + const any& value = *(it->second); |
| 29 | + if (value.type() == typeid(std::vector<TShape>)) { |
| 30 | + return GetVectorPrinter_( |
| 31 | + nnvm::get<std::vector<TShape> >(value)); |
| 32 | + } else if (value.type() == typeid(std::vector<int>)) { |
| 33 | + return GetVectorPrinter_( |
| 34 | + nnvm::get<std::vector<int> >(value)); |
| 35 | + } else if (value.type() == typeid(std::vector<std::string>)) { |
| 36 | + return GetVectorPrinter_( |
| 37 | + nnvm::get<std::vector<std::string> >(value)); |
| 38 | + } else { |
| 39 | + LOG(FATAL) << "Cannot handle type " << value.type().name(); |
| 40 | + return nullptr; |
| 41 | + } |
| 42 | +} |
| 43 | + |
| 44 | + |
13 | 45 | // print the graph ir in readable format |
14 | | -void PrintGraphIR_(Graph src, std::ostream& os) { // NOLINT(*) |
| 46 | +void PrintGraphIR_(Graph src, |
| 47 | + const std::vector<std::string>& join_entry_attrs, |
| 48 | + const std::vector<std::string>& join_node_attrs, |
| 49 | + std::ostream& os) { // NOLINT(*) |
15 | 50 | const IndexedGraph& idx = src.indexed_graph(); |
| 51 | + std::vector<std::function<void(uint32_t, std::ostream&)> > trigger; // NOLINT(*) |
| 52 | + |
| 53 | + for (const std::string& key : join_entry_attrs) { |
| 54 | + AttrPrinter fp = GetVectorPrinter(src, key); |
| 55 | + auto fprint = [&idx, key, fp]( |
| 56 | + uint32_t nid, std::ostream& os) { // NOLINT(*) |
| 57 | + const IndexedGraph::Node& inode = idx[nid]; |
| 58 | + os << ", " << key << "="; |
| 59 | + if (inode.source->num_outputs() != 1) { |
| 60 | + os << '['; |
| 61 | + for (uint32_t i = 0; i < inode.source->num_outputs(); ++i) { |
| 62 | + if (i != 0) os << ", "; |
| 63 | + fp(idx.entry_id(nid, i), os); |
| 64 | + } |
| 65 | + os << ']'; |
| 66 | + } else { |
| 67 | + fp(idx.entry_id(nid, 0), os); |
| 68 | + } |
| 69 | + }; |
| 70 | + trigger.push_back(fprint); |
| 71 | + } |
| 72 | + for (const std::string& key : join_node_attrs) { |
| 73 | + AttrPrinter fp = GetVectorPrinter(src, key); |
| 74 | + auto fprint = [&idx, key, fp]( |
| 75 | + uint32_t nid, std::ostream& os) { // NOLINT(*) |
| 76 | + os << key << "="; |
| 77 | + fp(idx.entry_id(nid, 0), os); |
| 78 | + }; |
| 79 | + trigger.push_back(fprint); |
| 80 | + } |
| 81 | + |
16 | 82 | os << "Graph("; |
17 | 83 | if (idx.input_nodes().size() < 4) { |
18 | 84 | for (size_t i = 0; i < idx.input_nodes().size(); ++i) { |
@@ -79,6 +145,10 @@ void PrintGraphIR_(Graph src, std::ostream& os) { // NOLINT(*) |
79 | 145 | } |
80 | 146 | os << "]"; |
81 | 147 | } |
| 148 | + // additional attribute trigger |
| 149 | + for (const auto& fp : trigger) { |
| 150 | + fp(nid, os); |
| 151 | + } |
82 | 152 | os << "\n"; |
83 | 153 | } |
84 | 154 | os << " ret "; |
@@ -112,7 +182,16 @@ void PrintGraphIR_(Graph src, std::ostream& os) { // NOLINT(*) |
112 | 182 | // save a graph to json |
113 | 183 | Graph PrintGraphIR(Graph src) { |
114 | 184 | std::ostringstream os; |
115 | | - PrintGraphIR_(src, os); |
| 185 | + std::vector<std::string> join_entry_attrs, join_node_attrs; |
| 186 | + if (src.attrs.count("join_entry_attrs") != 0) { |
| 187 | + join_entry_attrs = src.MoveCopyAttr<std::vector<std::string> >( |
| 188 | + "join_entry_attrs"); |
| 189 | + } |
| 190 | + if (src.attrs.count("join_node_attrs") != 0) { |
| 191 | + join_node_attrs = src.MoveCopyAttr<std::vector<std::string> >( |
| 192 | + "join_node_attrs"); |
| 193 | + } |
| 194 | + PrintGraphIR_(src, join_entry_attrs, join_node_attrs, os); |
116 | 195 | Graph ret; |
117 | 196 | ret.attrs["graphir"] = std::make_shared<any>(os.str()); |
118 | 197 | return ret; |
|
0 commit comments