Skip to content

Commit cfc5aff

Browse files
committed
[Collage] SubGraphs
See https://github.com/apache/tvm-rfcs/blob/main/rfcs/0062-collage.md. Collage works in units of 'sub-graphs', which are potential partitions of the overall Relay model. This PR introduces SubGraph (an arbitrary partitioning, without any implication about how it is to be represented), it's companion SubSubGraph (implying a representation as a function), and some supporting odds 'n ends.
1 parent 6a86c97 commit cfc5aff

File tree

11 files changed

+2607
-0
lines changed

11 files changed

+2607
-0
lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,7 @@ tvm_file_glob(GLOB_RECURSE RELAY_OP_SRCS
295295
)
296296
tvm_file_glob(GLOB_RECURSE RELAY_PASS_SRCS
297297
src/relay/analysis/*.cc
298+
src/relay/collage/*.cc
298299
src/relay/transforms/*.cc
299300
src/relay/quantize/*.cc
300301
)

src/relay/collage/README.md

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
<!--- Licensed to the Apache Software Foundation (ASF) under one -->
2+
<!--- or more contributor license agreements. See the NOTICE file -->
3+
<!--- distributed with this work for additional information -->
4+
<!--- regarding copyright ownership. The ASF licenses this file -->
5+
<!--- to you under the Apache License, Version 2.0 (the -->
6+
<!--- "License"); you may not use this file except in compliance -->
7+
<!--- with the License. You may obtain a copy of the License at -->
8+
9+
<!--- http://www.apache.org/licenses/LICENSE-2.0 -->
10+
11+
<!--- Unless required by applicable law or agreed to in writing, -->
12+
<!--- software distributed under the License is distributed on an -->
13+
<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
14+
<!--- KIND, either express or implied. See the License for the -->
15+
<!--- specific language governing permissions and limitations -->
16+
<!--- under the License. -->
17+
18+
The `CollagePartition` pass for finding optimal partitionings of Relay models.
19+
20+
See the [RFC](https://github.com/mbs-octoml/mbs-tvm-rfcs/blob/mbs-rfcs-collage/rfcs/xxxx-collage.md).
21+
22+
Based on:
23+
> *Collage: Automated Integration of Deep Learning Backends*
24+
> Byungsoo Jeon, Sunghyun Park, Peiyuan Liao, Sheng Xu, Tianqi Chen, Zhihao Jia
25+
26+
CAUTION: This is a prototype, do not use in prod.
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* \file src/relay/collage/dataflow_graph.cc
22+
* \brief A representation of the dataflow for an overall Relay expression.
23+
*/
24+
25+
#include "./dataflow_graph.h"
26+
27+
namespace tvm {
28+
namespace relay {
29+
namespace collage {
30+
31+
DataflowGraph::DataflowGraph(Expr expr) : expr_(std::move(expr)) {
32+
indexed_graph_ = CreateIndexedGraph(expr_);
33+
downstream_map_.reserve(indexed_graph_->size());
34+
for (PostDfsIndex index = 0; index < indexed_graph_->size(); ++index) {
35+
const Node* node = indexed_graph_->index_to_node(index);
36+
std::unordered_set<const Node*> downstream_nodes;
37+
node->AccumulateDownstreamNodes(&downstream_nodes);
38+
IndexSet index_set(indexed_graph_->size());
39+
for (const Node* downstream_node : downstream_nodes) {
40+
index_set.Add(downstream_node->index_);
41+
}
42+
downstream_map_.emplace_back(std::move(index_set));
43+
}
44+
}
45+
46+
} // namespace collage
47+
} // namespace relay
48+
} // namespace tvm

src/relay/collage/dataflow_graph.h

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* \file src/relay/collage/dataflow_graph.h
22+
* \brief A representation of the dataflow for an overall Relay expression.
23+
*/
24+
#ifndef TVM_RELAY_COLLAGE_DATAFLOW_GRAPH_H_
25+
#define TVM_RELAY_COLLAGE_DATAFLOW_GRAPH_H_
26+
27+
#include <tvm/relay/expr.h>
28+
29+
#include <memory>
30+
#include <vector>
31+
32+
#include "../ir/indexed_graph.h"
33+
#include "./index_set.h"
34+
35+
namespace tvm {
36+
namespace relay {
37+
namespace collage {
38+
39+
/*!
40+
* \brief Represents the dataflow of an overall Relay expression.
41+
*/
42+
class DataflowGraph {
43+
public:
44+
using Node = IndexedGraph<Expr>::Node;
45+
46+
explicit DataflowGraph(Expr expr);
47+
48+
size_t size() const { return indexed_graph_->size(); }
49+
const Node* index_to_node(PostDfsIndex index) const {
50+
return indexed_graph_->index_to_node(index);
51+
}
52+
const Node* item_to_node(const Expr& expr) const { return indexed_graph_->item_to_node(expr); }
53+
const Node* item_to_node(const ExprNode* expr_node) const {
54+
return indexed_graph_->item_to_node(expr_node);
55+
}
56+
const Expr& expr() const { return expr_; }
57+
const IndexedGraph<Expr>& indexed_graph() const { return *indexed_graph_; }
58+
59+
const IndexSet& downstream_of(PostDfsIndex index) const {
60+
ICHECK_LT(index, indexed_graph_->size());
61+
return downstream_map_[index];
62+
}
63+
64+
private:
65+
/*! \brief The overall expression. */
66+
Expr expr_;
67+
/*! \brief The indexed graph which captures the main dataflow. */
68+
std::unique_ptr<IndexedGraph<Expr>> indexed_graph_;
69+
/*! \brief Map from a node's PostDfsIndex to the set of it's downstream dataflow node indexes. */
70+
std::vector<IndexSet> downstream_map_;
71+
};
72+
73+
} // namespace collage
74+
} // namespace relay
75+
} // namespace tvm
76+
77+
#endif // TVM_RELAY_COLLAGE_DATAFLOW_GRAPH_H_

src/relay/collage/index_set.cc

Lines changed: 231 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,231 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* \file src/relay/collage/index_set.cc
22+
* \brief Efficient representation of a set of post-dfs indexes.
23+
*/
24+
25+
#include "./index_set.h"
26+
27+
namespace tvm {
28+
namespace relay {
29+
namespace collage {
30+
31+
// TODO(mbs): These should operate one-word-at-a-time
32+
33+
IndexSet::IndexSet(size_t size, const std::vector<size_t>& indexes) : bitvec_(size, false) {
34+
for (size_t index : indexes) {
35+
ICHECK_LT(index, bitvec_.size());
36+
ICHECK(!bitvec_[index]);
37+
bitvec_[index] = true;
38+
}
39+
}
40+
41+
IndexSet IndexSet::operator&(const IndexSet& that) const {
42+
ICHECK_EQ(bitvec_.size(), that.bitvec_.size());
43+
std::vector<bool> result(bitvec_.size(), false);
44+
for (size_t index = 0; index < bitvec_.size(); ++index) {
45+
result[index] = bitvec_[index] && that.bitvec_[index];
46+
}
47+
return IndexSet(result);
48+
}
49+
50+
IndexSet IndexSet::operator|(const IndexSet& that) const {
51+
ICHECK_EQ(bitvec_.size(), that.bitvec_.size());
52+
std::vector<bool> result(bitvec_.size(), false);
53+
for (size_t index = 0; index < bitvec_.size(); ++index) {
54+
result[index] = bitvec_[index] || that.bitvec_[index];
55+
}
56+
return IndexSet(result);
57+
}
58+
59+
IndexSet IndexSet::operator-(const IndexSet& that) const {
60+
ICHECK_EQ(bitvec_.size(), that.bitvec_.size());
61+
std::vector<bool> result(bitvec_.size());
62+
for (size_t index = 0; index < bitvec_.size(); ++index) {
63+
result[index] = bitvec_[index] && !that.bitvec_[index];
64+
}
65+
return IndexSet(result);
66+
}
67+
68+
bool IndexSet::AreDisjoint(const IndexSet& that) const {
69+
ICHECK_EQ(bitvec_.size(), that.bitvec_.size());
70+
for (size_t index = 0; index < bitvec_.size(); index++) {
71+
if (bitvec_[index] && that.bitvec_[index]) {
72+
return false;
73+
}
74+
}
75+
return true;
76+
}
77+
78+
bool IndexSet::IsSubset(const IndexSet& that) const {
79+
ICHECK_EQ(bitvec_.size(), that.bitvec_.size());
80+
for (size_t index = 0; index < bitvec_.size(); index++) {
81+
if (bitvec_[index] && !that.bitvec_[index]) {
82+
return false;
83+
}
84+
}
85+
return true;
86+
}
87+
88+
bool IndexSet::Intersects(const IndexSet& that) const {
89+
ICHECK_EQ(bitvec_.size(), that.bitvec_.size());
90+
for (size_t index = 0; index < bitvec_.size(); index++) {
91+
if (bitvec_[index] && that.bitvec_[index]) {
92+
return true;
93+
}
94+
}
95+
return false;
96+
}
97+
98+
IndexSet IndexSet::Subst(size_t new_size, const IndexSubst& subst) const {
99+
std::vector<bool> result(new_size, false);
100+
for (PostDfsIndex index = 0; index < bitvec_.size(); ++index) {
101+
if (!bitvec_[index]) {
102+
continue;
103+
}
104+
auto itr = subst.find(index);
105+
ICHECK(itr != subst.end());
106+
PostDfsIndex new_index = itr->second;
107+
ICHECK(new_index < new_size);
108+
ICHECK(!result[new_index]);
109+
result[new_index] = true;
110+
}
111+
return IndexSet(result);
112+
}
113+
114+
size_t IndexSet::PopCount() const {
115+
size_t n = 0;
116+
for (size_t index = 0; index < bitvec_.size(); index++) {
117+
if (bitvec_[index]) {
118+
++n;
119+
}
120+
}
121+
return n;
122+
}
123+
124+
bool IndexSet::IsZero() const {
125+
for (size_t index = 0; index < bitvec_.size(); index++) {
126+
if (bitvec_[index]) {
127+
return false;
128+
}
129+
}
130+
return true;
131+
}
132+
133+
size_t IndexSet::FirstInsideIndex() const {
134+
for (size_t index = 0; index < bitvec_.size(); index++) {
135+
if (bitvec_[index]) {
136+
return index;
137+
}
138+
}
139+
return bitvec_.size();
140+
}
141+
142+
size_t IndexSet::LastInsideIndex() const {
143+
for (size_t i = bitvec_.size(); i > 0; i--) {
144+
const size_t index = i - 1;
145+
if (bitvec_[index]) {
146+
return index;
147+
}
148+
}
149+
return bitvec_.size();
150+
}
151+
152+
size_t IndexSet::NextIndex(size_t index) const {
153+
ICHECK_LT(index, bitvec_.size());
154+
for (index++; index < bitvec_.size(); index++) {
155+
if (bitvec_[index]) {
156+
return index;
157+
}
158+
}
159+
return bitvec_.size();
160+
}
161+
162+
size_t IndexSet::FirstOutsideIndex() const {
163+
for (size_t index = 0; index < bitvec_.size(); index++) {
164+
if (!bitvec_[index]) {
165+
return index;
166+
}
167+
}
168+
return bitvec_.size();
169+
}
170+
171+
bool IndexSet::operator==(const IndexSet& that) const {
172+
ICHECK_EQ(bitvec_.size(), that.bitvec_.size());
173+
return bitvec_ == that.bitvec_;
174+
}
175+
176+
bool IndexSet::operator!=(const IndexSet& that) const {
177+
ICHECK_EQ(bitvec_.size(), that.bitvec_.size());
178+
return bitvec_ != that.bitvec_;
179+
}
180+
181+
bool IndexSet::operator<(const IndexSet& that) const {
182+
ICHECK_EQ(bitvec_.size(), that.bitvec_.size());
183+
for (size_t index = 0; index < bitvec_.size(); index++) {
184+
if (bitvec_[index] && !that.bitvec_[index]) {
185+
return true;
186+
}
187+
if (!bitvec_[index] && that.bitvec_[index]) {
188+
return false;
189+
}
190+
}
191+
return false;
192+
}
193+
194+
size_t IndexSet::hash() const {
195+
std::hash<std::vector<bool>> h;
196+
return h(bitvec_);
197+
}
198+
199+
std::string IndexSet::ToString() const {
200+
std::ostringstream os;
201+
os << "{";
202+
bool first = true;
203+
for (size_t start = 0; start < bitvec_.size(); /*no-op*/) {
204+
if (!bitvec_[start]) {
205+
++start;
206+
continue;
207+
}
208+
size_t end;
209+
for (end = start + 1; end < bitvec_.size() && bitvec_[end]; ++end) {
210+
/*no-op*/
211+
}
212+
if (first) {
213+
first = false;
214+
} else {
215+
os << ",";
216+
}
217+
os << start;
218+
if (end > start + 2) {
219+
os << ".." << (end - 1);
220+
start = end;
221+
} else {
222+
++start;
223+
}
224+
}
225+
os << "}";
226+
return os.str();
227+
}
228+
229+
} // namespace collage
230+
} // namespace relay
231+
} // namespace tvm

0 commit comments

Comments
 (0)