|
| 1 | +/*! |
| 2 | + * Copyright (c) 2017 by Contributors |
| 3 | + * \file graph_transform.h |
| 4 | + * \brief A mutator class that does local pattern matching and mutates a node. |
| 5 | +*/ |
| 6 | +#ifndef NNVM_COMPILER_GRAPH_TRANSFORM_H_ |
| 7 | +#define NNVM_COMPILER_GRAPH_TRANSFORM_H_ |
| 8 | + |
| 9 | +#include <nnvm/graph.h> |
| 10 | +#include <vector> |
| 11 | + |
| 12 | +namespace nnvm { |
| 13 | +namespace compiler { |
| 14 | + |
| 15 | +/*! |
| 16 | + * \brief Transform the graph to build a new Graph, in post DFS order. |
| 17 | + * |
| 18 | + * Automatically copies node when some of its children or control_deps changed. |
| 19 | + * This function won't be called in Variable. |
| 20 | + * |
| 21 | + * \param graph The original graph |
| 22 | + * |
| 23 | + * \param ftransform Function of (int nid, const Node* node, std::vector<NodeEntry>* out) -> bool |
| 24 | + * |
| 25 | + * If empty vector is returned, it means original entries should be kept. |
| 26 | + * |
| 27 | + * \tparam FTransform The transformation function. |
| 28 | + */ |
| 29 | +template<typename FTransform> |
| 30 | +Graph GraphTransform(Graph graph, FTransform ftransform) { |
| 31 | + const IndexedGraph& idx = graph.indexed_graph(); |
| 32 | + // new nodes |
| 33 | + std::vector<NodeEntry> new_entry_map(idx.num_node_entries()); |
| 34 | + std::vector<bool> updated(idx.num_node_entries(), false); |
| 35 | + |
| 36 | + // setup inputs and placeholder. |
| 37 | + for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) { |
| 38 | + const auto& inode = idx[nid]; |
| 39 | + if (inode.source->is_variable()) continue; |
| 40 | + bool need_copy = false; |
| 41 | + for (const IndexedGraph::NodeEntry& e : inode.inputs) { |
| 42 | + if (updated[idx.entry_id(e)]) { |
| 43 | + need_copy = true; break; |
| 44 | + } |
| 45 | + } |
| 46 | + if (!need_copy) { |
| 47 | + for (const uint32_t cid : inode.control_deps) { |
| 48 | + const auto& cnode = idx[cid]; |
| 49 | + for (uint32_t i = 0 ; i < cnode.source->num_outputs(); ++i) { |
| 50 | + if (updated[idx.entry_id(cid, i)]) { |
| 51 | + need_copy = true; |
| 52 | + } |
| 53 | + } |
| 54 | + if (need_copy) break; |
| 55 | + } |
| 56 | + } |
| 57 | + |
| 58 | + if (!need_copy) { |
| 59 | + std::vector<NodeEntry> ret; |
| 60 | + if (ftransform(nid, inode.source, &ret)) { |
| 61 | + CHECK_EQ(ret.size(), static_cast<size_t>(inode.source->num_outputs())); |
| 62 | + for (uint32_t i = 0 ; i < inode.source->num_outputs(); ++i) { |
| 63 | + updated[idx.entry_id(nid, i)] = true; |
| 64 | + new_entry_map[idx.entry_id(nid, i)] = ret[i]; |
| 65 | + } |
| 66 | + } |
| 67 | + } else { |
| 68 | + NodePtr node = Node::Create(); |
| 69 | + node->attrs = inode.source->attrs; |
| 70 | + for (size_t i = 0; i < inode.inputs.size(); ++i) { |
| 71 | + const IndexedGraph::NodeEntry& e = inode.inputs[i]; |
| 72 | + if (updated[idx.entry_id(e)]) { |
| 73 | + node->inputs.push_back(new_entry_map[idx.entry_id(e)]); |
| 74 | + } else { |
| 75 | + node->inputs.push_back(inode.source->inputs[i]); |
| 76 | + } |
| 77 | + } |
| 78 | + for (size_t i = 0; i < inode.control_deps.size(); ++i) { |
| 79 | + const uint32_t cid = inode.control_deps[i]; |
| 80 | + const auto& cnode = idx[cid]; |
| 81 | + CHECK_NE(cnode.source->num_outputs(), 0U); |
| 82 | + NodePtr selected_ptr; |
| 83 | + for (uint32_t j = 0 ; j < cnode.source->num_outputs(); ++j) { |
| 84 | + NodePtr cptr = updated[idx.entry_id(cid, j)] ? |
| 85 | + new_entry_map[idx.entry_id(cid, j)].node : inode.source->control_deps[i]; |
| 86 | + if (selected_ptr == nullptr) { |
| 87 | + selected_ptr = std::move(cptr); |
| 88 | + } else { |
| 89 | + CHECK(selected_ptr.get() == cptr.get()) |
| 90 | + << "Control dependency node changed to more than one node"; |
| 91 | + } |
| 92 | + } |
| 93 | + node->control_deps.push_back(selected_ptr); |
| 94 | + } |
| 95 | + std::vector<NodeEntry> ret; |
| 96 | + if (ftransform(nid, node.get(), &ret)) { |
| 97 | + CHECK_EQ(ret.size(), static_cast<size_t>(inode.source->num_outputs())); |
| 98 | + for (uint32_t i = 0 ; i < inode.source->num_outputs(); ++i) { |
| 99 | + updated[idx.entry_id(nid, i)] = true; |
| 100 | + new_entry_map[idx.entry_id(nid, i)] = ret[i]; |
| 101 | + } |
| 102 | + } else { |
| 103 | + for (uint32_t i = 0 ; i < inode.source->num_outputs(); ++i) { |
| 104 | + updated[idx.entry_id(nid, i)] = true; |
| 105 | + new_entry_map[idx.entry_id(nid, i)] = NodeEntry{node, i, 0}; |
| 106 | + } |
| 107 | + } |
| 108 | + } |
| 109 | + } |
| 110 | + Graph ret; |
| 111 | + for (size_t i = 0; i < idx.outputs().size(); ++i) { |
| 112 | + const IndexedGraph::NodeEntry& e = idx.outputs()[i]; |
| 113 | + if (updated[idx.entry_id(e)]) { |
| 114 | + ret.outputs.push_back(new_entry_map[idx.entry_id(e)]); |
| 115 | + } else { |
| 116 | + ret.outputs.push_back(graph.outputs[i]); |
| 117 | + } |
| 118 | + } |
| 119 | + return ret; |
| 120 | +} |
| 121 | + |
| 122 | +} // namespace compiler |
| 123 | +} // namespace nnvm |
| 124 | + |
| 125 | +#endif // NNVM_COMPILER_GRAPH_TRANSFORM_H_ |
0 commit comments