
Commit 7055803

merrymercy authored and tmoreau89 committed
[AUTOTVM] End2End autotvm support for vta (apache#18)
* support tuning a whole network
* pass unit test
* update tune resnet
* update all
1 parent 830c532 commit 7055803

File tree

20 files changed: +924, -1160 lines


python/tvm/autotvm/measure/measure_methods.py

Lines changed: 12 additions & 1 deletion
@@ -207,7 +207,18 @@ def set_task(self, task):
                         for x in arg_bufs]
             func = build(s, arg_bufs, "llvm")
             tvm_buf = [nd.array(x) for x in self.ref_input]
-            func(*tvm_buf)
+
+            def _run_func():
+                """Run the tvm function in a thread,
+                because there are some issues with Python multiprocessing and the thread pool in tvm.
+                """
+                func(*tvm_buf)
+
+            thread = threading.Thread(target=_run_func)
+            thread.start()
+            thread.join()
+            del thread
+
             self.ref_output = [x.asnumpy() for x in tvm_buf]
 
     def get_build_kwargs(self):
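The change above is small but reusable. A minimal, self-contained sketch of the same workaround (the helper name run_in_fresh_thread is illustrative, not part of this commit):

import threading

def run_in_fresh_thread(func, *args):
    """Run func in a short-lived thread and block until it finishes.

    Same workaround as in the diff above: isolating the call in its own
    thread sidesteps interactions between Python multiprocessing and
    TVM's internal thread pool.
    """
    thread = threading.Thread(target=func, args=args)
    thread.start()
    thread.join()

# usage: run_in_fresh_thread(func, *tvm_buf) instead of func(*tvm_buf)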

python/tvm/autotvm/task/nnvm_integration.py

Lines changed: 65 additions & 136 deletions
@@ -5,6 +5,7 @@
 """
 import warnings
 import logging
+import sys
 
 
 from ... import tensor, placeholder, create_schedule, target as _target
@@ -49,9 +50,9 @@ def deserialize_args(args):
 # Task extractor for nnvm graph
 class TaskExtractEnv:
     """Global environment for extracting tuning tasks from nnvm graph"""
-    current = None
+    registered = False
 
-    def __init__(self):
+    def __init__(self, wanted_symbols):
         import topi
         import nnvm
 
@@ -83,46 +84,62 @@ def __init__(self):
             topi.nn.dense: [topi.generic.schedule_dense],
         }
 
-        self._register_tracing()
+        # support reflection for tracing
+        self.func_to_reflection = {
+            topi.nn.conv2d: lambda x: setattr(topi.nn, 'conv2d', x),
+            topi.nn.depthwise_conv2d_nchw: lambda x: setattr(topi.nn, 'depthwise_conv2d_nchw', x),
+            topi.nn.conv2d_transpose_nchw: lambda x: setattr(topi.nn, 'conv2d_transpose_nchw', x),
+            topi.nn.dense: lambda x: setattr(topi.nn, 'dense', x),
+        }
+
+
+        self.wanted_topi_funcs = []
+        for sym_name in wanted_symbols:
+            if sym_name in self.symbol2topi:
+                self.wanted_topi_funcs.extend(self.symbol2topi[sym_name])
+            else:
+                warnings.warn("Symbol %s is not tunable, ignored" % sym_name)
+
         self._register_topi_task()
         self.task_collection = []
-        self.wanted_topi_funcs = list(self.topi_to_task.keys())
+        self.modified_funcs = []
 
-    def _register_tracing(self):
-        """Register tracing function to track the topi function call"""
-        # register topi compute for "tracing" target
-        for topi_compute in self.topi_to_task:
+    def __enter__(self):
+        self.task_collection = []
+        self.modified_funcs = []
+
+        for topi_compute in self.wanted_topi_funcs:
             def _local_scope(compute_func):
                 """start a scope to hold the local function in for loop"""
 
-                @compute_func.register("tracing", )
-                def _tracing_topi_compute(*args, **kwargs):
-                    assert not kwargs, "Do not support extracting tuning tasks when" \
-                        "kwargs is used in TOPI function call." \
+                def _tracing_wrapper(*args, **kwargs):
+                    assert not kwargs, "Do not support extracting tuning tasks when " \
+                        "kwargs is used in TOPI function call. " \
                         "Please modify it to use only positional args."
 
-                    if compute_func in self.wanted_topi_funcs:  # record this call
-                        key = (self.topi_to_task[compute_func], serialize_args(args))
-                        if key not in self.task_collection:
-                            self.task_collection.append(key)
+                    key = (self.topi_to_task[compute_func], serialize_args(args))
+                    if key not in self.task_collection:
+                        self.task_collection.append(key)
+
+                    return compute_func(*args, **kwargs)
+
+                self.func_to_reflection[topi_compute](_tracing_wrapper)
+                self.modified_funcs.append(topi_compute)
 
-                    return compute_func.fdefault(*args)
             _local_scope(topi_compute)
 
-        # register topi schedule for "tracing" target
-        for topi_compute in self.topi_to_task:
-            for topi_schedule in self.topi_to_schedule[topi_compute]:
-                def _local_scope_(schedule_func):
-                    """start a scope to hold the local function in for loop"""
+        return self
 
-                    @schedule_func.register("tracing", )
-                    def _tracing_topi_compute(outs):
-                        outs = [outs] if isinstance(outs, tensor.Tensor) else outs
-                        return create_schedule([x.op for x in outs])
-                _local_scope_(topi_schedule)
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        # revert modification
+        for func in self.modified_funcs:
+            self.func_to_reflection[func](func)
 
     def _register_topi_task(self):
         """register tuning wrapper for topi function"""
+        if TaskExtractEnv.registered:
+            return
+        TaskExtractEnv.registered = True
         import topi
 
         # Tuning wrapper for topi functions
@@ -175,17 +192,6 @@ def _topi_nn_dense(*args, **kwargs):
                 return s, [data, weight, bias, C]
             return s, [data, weight, C]
 
-    def reset(self, wanted_topi_funcs):
-        """Reset task collections
-
-        Parameters
-        ----------
-        wanted_topi_funcs: List of function
-            The topi function to be extracted
-        """
-        self.task_collection = []
-        self.wanted_topi_funcs = wanted_topi_funcs
-
     def get_tasks(self):
         """Get collected tasks
 
@@ -196,25 +202,11 @@ def get_tasks(self):
         """
         return self.task_collection
 
-    @staticmethod
-    def get():
-        """Get the single instance of TaskExtractEnv
-
-        Returns
-        -------
-        env: TaskExtractEnv
-            The single instance of TaskExtractEnv
-        """
-        if not TaskExtractEnv.current:
-            TaskExtractEnv.current = TaskExtractEnv()
-        return TaskExtractEnv.current
-
 
 def extract_from_graph(graph, shape, dtype, target, symbols, target_host=None):
     """ Extract tuning tasks from a nnvm graph.
 
-    This function collects tuning tasks by building the graph
-    with a "tracing" target and tracing all the calls to topi.
+    This function collects tuning tasks by building the graph and tracing all the calls to topi.
 
     Parameters
     ----------
@@ -237,97 +229,34 @@ def extract_from_graph(graph, shape, dtype, target, symbols, target_host=None):
         collected tasks
     """
     import nnvm.compiler
+    import topi
 
-    env = TaskExtractEnv.get()
+    env = TaskExtractEnv(symbols)
 
-    topi_funcs = []
-    for sym_name in symbols:
-        if sym_name in env.symbol2topi:
-            topi_funcs.extend(env.symbol2topi[sym_name])
-        else:
-            warnings.warn("Symbol %s is not tunable, ignored" % sym_name)
+    with env:
+        # disable logger temporarily
+        old_state = logger.disabled
+        logger.disabled = True
 
-    # run compiler to collect all TOPI calls during compilation
-    env.reset(topi_funcs)
+        # run compiler to collect all TOPI calls during compilation
+        nnvm.compiler.engine.clear_cache()
+        nnvm.compiler.build(graph, target=target, shape=shape, dtype=dtype)
+        nnvm.compiler.engine.clear_cache()
 
-    # disable logger temporarily
-    old_state = logger.disabled
-    logger.disabled = True
-
-    # use a "tracing" target to do a fake compile for collecting topi calls
-    tracing_target = _target.create("llvm -device=tracing")
-    nnvm.compiler.engine.clear_cache()
-    nnvm.compiler.build(graph, target=tracing_target, shape=shape, dtype=dtype)
-
-    logger.disabled = old_state
+        logger.disabled = old_state
 
     # create tasks for target
     tasks = []
     for task_name, args in env.get_tasks():
-        tasks.append(create(task_name, args,
-                            target=target, target_host=target_host,
-                            template_key='direct'))
+        try:
+            tsk = create(task_name, args,
+                         target=target, target_host=target_host,
+                         template_key='direct')
+            tasks.append(tsk)
+        except topi.InvalidShapeError:
+            print("shape error")
 
     return tasks
 
-
-def extract_from_multiple_graph(graphs, shapes, dtypes, target, symbols, target_host=None):
-    """ Extract tuning tasks from multiple nnvm graphs.
-
-    This function is the multiple graph version of extract_from_graph
-
-    Parameters
-    ----------
-    graphs : List of Graph
-        The list of graphs to tune
-    shapes : List of dict of str to tuple
-        The input shape to the graph
-    dtypes : List of str or dict of str to str
-        The input types to the graph
-    target: tvm.target.Target
-        The compilation target
-    symbols : Array of nnvm.symbol
-        Array of nnvm symbols want to be tuned
-    target_host: tvm.target.Target
-        The host compilation target
-
-    Returns
-    -------
-    task: Array of autotvm.task.Task
-        collected tasks
-    """
-    import nnvm.compiler
-
-    env = TaskExtractEnv.get()
-
-    topi_funcs = []
-    for sym_name in symbols:
-        if sym_name in env.symbol2topi:
-            topi_funcs.extend(env.symbol2topi[sym_name])
-        else:
-            warnings.warn("Symbol %s is not tunable, ignored" % sym_name)
-
-    # run compiler to collect all TOPI calls during compilation
-    env.reset(topi_funcs)
-
-    # disable logger temporarily
-    old_state = logger.disabled
-    logger.disabled = True
-
-    # use a "tracing" target to do a fake compile for collecting topi calls
-    tracing_target = _target.create("llvm -device=tracing")
-
-    nnvm.compiler.engine.clear_cache()
-    for graph, shape, dtype in zip(graphs, shapes, dtypes):
-        nnvm.compiler.build(graph, target=tracing_target, shape=shape, dtype=dtype)
-
-    logger.disabled = old_state
-
-    # create tasks for target
-    tasks = []
-    for task_name, args in env.get_tasks():
-        tasks.append(create(task_name, args,
-                            target=target, target_host=target_host,
-                            template_key='direct'))
-
-    return tasks
+def extract_from_multiple_graph(graph, shape, dtype, target, symbols, target_host=None):
+    pass
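The heart of this rewrite is replacing the fake "tracing" compile target with monkey-patching: on __enter__ the wanted topi functions are swapped for recording wrappers, and on __exit__ the originals are restored via the reflection table. A minimal, self-contained sketch of that pattern (TraceEnv and its names are illustrative, not TVM API):

class TraceEnv:
    """Patch selected functions of a module on entry; restore them on exit."""

    def __init__(self, module, names):
        self.module = module        # module whose attributes get patched
        self.names = names          # attribute names to wrap
        self.calls = []             # recorded (name, args) pairs
        self._originals = {}

    def __enter__(self):
        for name in self.names:
            original = getattr(self.module, name)
            self._originals[name] = original

            def wrapper(*args, _name=name, _orig=original):
                self.calls.append((_name, args))  # record the call
                return _orig(*args)               # still run the real function

            setattr(self.module, name, wrapper)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # revert modification, as TaskExtractEnv.__exit__ does
        for name, original in self._originals.items():
            setattr(self.module, name, original)

# usage:
import math
env = TraceEnv(math, ["sqrt"])
with env:
    math.sqrt(4.0)
print(env.calls)   # [('sqrt', (4.0,))]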

python/tvm/target.py

Lines changed: 7 additions & 0 deletions
@@ -473,6 +473,13 @@ def rasp(options=None):
     return arm_cpu('rasp3b', options)
 
 
+def vta(model='unknown', options=None):
+    opts = ["-device=vta", '-keys=cpu', '-model=%s' % model]
+    opts = _merge_opts(opts, options)
+    ret = _api_internal._TargetCreate("ext_dev", *opts)
+    return ret
+
+
 def create(target_str):
     """Get a target given target string.
 
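With this helper in place, a VTA target can be created in one call. A hedged usage sketch (the model string is illustrative):

import tvm

# builds an "ext_dev" target with options -device=vta -keys=cpu -model=...
tgt = tvm.target.vta(model='sim')
print(tgt)        # ext_dev -device=vta -keys=cpu -model=sim
print(tgt.keys)   # should include both 'vta' (from -device) and 'cpu' (from -keys)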

src/codegen/build_module.cc

Lines changed: 8 additions & 3 deletions
@@ -39,6 +39,7 @@ Target CreateTarget(const std::string& target_name,
 
   std::string libs_flag = "-libs=";
   std::string device_flag = "-device=";
+  std::string keys_flag = "-keys=";
   for (auto& item : options) {
     t->options_array.push_back(ir::StringImm::make(item));
 
@@ -50,12 +51,16 @@ Target CreateTarget(const std::string& target_name,
       }
     } else if (item.find(device_flag) == 0) {
       t->device_name = item.substr(device_flag.length());
+      t->keys_array.push_back(ir::StringImm::make(t->device_name));
+    } else if (item.find(keys_flag) == 0) {
+      std::stringstream ss(item.substr(keys_flag.length()));
+      std::string key_item;
+      while (std::getline(ss, key_item, ',')) {
+        t->keys_array.push_back(ir::StringImm::make(key_item));
+      }
     }
   }
 
-  if (t->device_name.length() > 0) {
-    t->keys_array.push_back(ir::StringImm::make(t->device_name));
-  }
   t->device_type = kDLCPU;
   t->thread_warp_size = 1;
   if (target_name == "llvm") {
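The net effect is that keys_array is now populated during option parsing: -device= contributes its value and -keys= contributes each comma-separated entry. A hedged illustration from the Python side (exact key ordering may differ by TVM version):

import tvm

# each comma-separated entry of -keys= becomes one target key
tgt = tvm.target.create("llvm -device=vta -keys=cpu,custom")
print(tgt.keys)   # expected to contain 'vta', 'cpu' and 'custom'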

topi/python/topi/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -34,6 +34,10 @@
 from . import image
 from . import sparse
 from . import hls
+
+# some short cut
+from .util import InvalidShapeError
+
 # not import testing by default
 # because testing can have extra deps that are not necessary
 # we can import them from test cases explicitly

topi/python/topi/util.py

Lines changed: 4 additions & 0 deletions
@@ -6,6 +6,10 @@
 import tvm
 from . import tag
 
+class InvalidShapeError(ValueError):
+    """Invalid shape for a topi function (e.g. calling a winograd template for a non-3x3 kernel)"""
+    pass
+
 def traverse_inline(s, final_op, callback):
     """Traverse computation graph and do auto inline
 
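A brief sketch of the intended use: a topi template raises InvalidShapeError for shapes it cannot handle, and callers such as extract_from_graph above catch it and skip the task. The template function here is hypothetical:

import topi

def winograd_template(kernel_shape):
    # hypothetical template: winograd only supports 3x3 kernels
    if tuple(kernel_shape[-2:]) != (3, 3):
        raise topi.InvalidShapeError("winograd requires a 3x3 kernel")
    # ... build compute and schedule here ...

try:
    winograd_template((64, 64, 5, 5))
except topi.InvalidShapeError as err:
    print("skipped:", err)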

vta/python/vta/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -18,6 +18,7 @@
 # to maintain minimum dependency on the board
 if sys.argv[0] not in ("-c", "-m"):
     from . import top
-    from .build_module import build_config, lower, build
     from . import graph
+
+    from .build_module import build_config, lower, build, vta_autotvm_build_func
 from .ptr_alias import reinterpret
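The newly exported vta_autotvm_build_func is meant to be plugged into autotvm's measurement options as a custom build callback. A hedged sketch (the measure_func and build_func keywords follow the 2018-era autotvm.measure_option API and are assumptions here):

from tvm import autotvm
import vta

# hand VTA's build function to autotvm so measurement builds go through VTA
measure_option = autotvm.measure_option(
    measure_func='local',
    build_func=vta.vta_autotvm_build_func)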
