Skip to content

Commit 88163ec

Browse files
ptrendxtqchen
authored andcommitted
Ghost nodes in NNVM graph (#3290)
1 parent 165aa0d commit 88163ec

File tree

2 files changed

+14
-0
lines changed

2 files changed

+14
-0
lines changed

nnvm/include/nnvm/op_attr_types.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,17 @@ using FInferType = FInferNodeEntryAttr<int>;
136136
*/
137137
using TIsBackward = bool;
138138

139+
/*!
140+
* \brief Whether this op is a ghost node.
141+
* If TIsGhost is true:
142+
* - The node with this op will not be visible in the indexed graph.
143+
*
144+
* \note Register under "TIsGhost"
145+
* This enables shape/type inference for backward nodes when
146+
* fusion is present.
147+
*/
148+
using TIsGhost = bool;
149+
139150
/*!
140151
* \brief Get possible inplace options.
141152
* This function enables optimization to reuse memory of inputs in output.

nnvm/src/core/graph.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,8 @@ IndexedGraph::IndexedGraph(const Graph &g) {
7676

7777
DFSVisit(g.outputs, [this, &inputs_rptr, &control_rptr, &subgraphs]
7878
(const NodePtr& n) {
79+
const auto& is_ghost = Op::GetAttr<TIsGhost>("TIsGhost");
80+
if (!n->is_variable() && is_ghost.get(n->op(), false)) return;
7981
CHECK_LT(nodes_.size(), std::numeric_limits<uint32_t>::max());
8082
uint32_t nid = static_cast<uint32_t>(nodes_.size());
8183
CHECK(n);
@@ -103,6 +105,7 @@ IndexedGraph::IndexedGraph(const Graph &g) {
103105
inputs_rptr.push_back(input_entries_.size());
104106
// control deps
105107
for (const auto& nptr : n->control_deps) {
108+
if (!nptr->is_variable() && is_ghost.get(nptr->op(), false)) continue;
106109
auto it = node2index_.find(nptr.get());
107110
CHECK(it != node2index_.end() && it->first == nptr.get());
108111
control_deps_.push_back(it->second);

0 commit comments

Comments
 (0)