Skip to content

Commit e3e60ea

Browse files
merrymercy authored and tmoreau89 committed
[AUTOTVM] End2End autotvm support for vta (apache#18)
* support tuning a whole network
* pass unit test
* update tune resnet
* update all
1 parent b4df367 commit e3e60ea

File tree

20 files changed

+873
-906
lines changed

20 files changed

+873
-906
lines changed

python/tvm/autotvm/measure/measure_methods.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,18 @@ def set_task(self, task):
207207
for x in arg_bufs]
208208
func = build(s, arg_bufs, "llvm")
209209
tvm_buf = [nd.array(x) for x in self.ref_input]
210-
func(*tvm_buf)
210+
211+
def _run_func():
212+
"""Run tvm function in a thread.
213+
Because there is some issues with python multiprocessing and the thread pool in tvm
214+
"""
215+
func(*tvm_buf)
216+
217+
thread = threading.Thread(target=_run_func)
218+
thread.start()
219+
thread.join()
220+
del thread
221+
211222
self.ref_output = [x.asnumpy() for x in tvm_buf]
212223

213224
def get_build_kwargs(self):

python/tvm/autotvm/task/nnvm_integration.py

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
"""
66
import warnings
77
import logging
8+
import sys
89

910

1011
from ... import target as _target
@@ -18,8 +19,7 @@
1819
def extract_from_graph(graph, shape, dtype, target, symbols, target_host=None):
1920
""" Extract tuning tasks from a nnvm graph.
2021
21-
This function collects tuning tasks by building the graph
22-
with a "tracing" target and tracing all the calls to topi.
22+
This function collects tuning tasks by building the graph and trace all the calls to topi.
2323
2424
Parameters
2525
----------
@@ -45,7 +45,7 @@ def extract_from_graph(graph, shape, dtype, target, symbols, target_host=None):
4545
import nnvm
4646
import topi
4747

48-
env = TaskExtractEnv.get()
48+
env = TaskExtractEnv(symbols)
4949

5050
#NOTE: To add more symbols, you only need to change the following lists
5151
#nnvm symbol -> topi compute
@@ -63,26 +63,23 @@ def extract_from_graph(graph, shape, dtype, target, symbols, target_host=None):
6363
else:
6464
warnings.warn("Symbol %s is not tunable, ignored" % sym_name)
6565

66-
# run compiler to collect all TOPI calls during compilation
67-
env.reset(topi_funcs)
66+
# run compiler to collect all TOPI calls during compilation
67+
nnvm.compiler.engine.clear_cache()
68+
nnvm.compiler.build(graph, target=target, shape=shape, dtype=dtype)
69+
nnvm.compiler.engine.clear_cache()
6870

69-
# disable logger temporarily
70-
old_state = logger.disabled
71-
logger.disabled = True
72-
73-
# use a "tracing" target to do a fake compile for collecting topi calls
74-
tracing_target = _target.create("llvm -device=tracing")
75-
nnvm.compiler.engine.clear_cache()
76-
nnvm.compiler.build(graph, target=tracing_target, shape=shape, dtype=dtype)
77-
78-
logger.disabled = old_state
71+
logger.disabled = old_state
7972

8073
# create tasks for target
8174
tasks = []
8275
for task_name, args in env.get_tasks():
83-
tasks.append(create(task_name, args,
84-
target=target, target_host=target_host,
85-
template_key='direct'))
76+
try:
77+
tsk = create(task_name, args,
78+
target=target, target_host=target_host,
79+
template_key='direct')
80+
tasks.append(tsk)
81+
except topi.InvalidShapeError:
82+
print("shape error")
8683

8784
return tasks
8885

python/tvm/target.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -473,6 +473,13 @@ def rasp(options=None):
473473
return arm_cpu('rasp3b', options)
474474

475475

476+
def vta(model='unknown', options=None):
477+
opts = ["-device=vta", '-keys=cpu', '-model=%s' % model]
478+
opts = _merge_opts(opts, options)
479+
ret = _api_internal._TargetCreate("ext_dev", *opts)
480+
return ret
481+
482+
476483
def create(target_str):
477484
"""Get a target given target string.
478485

src/codegen/build_module.cc

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ Target CreateTarget(const std::string& target_name,
3939

4040
std::string libs_flag = "-libs=";
4141
std::string device_flag = "-device=";
42+
std::string keys_flag = "-keys=";
4243
for (auto& item : options) {
4344
t->options_array.push_back(ir::StringImm::make(item));
4445

@@ -50,12 +51,16 @@ Target CreateTarget(const std::string& target_name,
5051
}
5152
} else if (item.find(device_flag) == 0) {
5253
t->device_name = item.substr(device_flag.length());
54+
t->keys_array.push_back(ir::StringImm::make(t->device_name));
55+
} else if (item.find(keys_flag) == 0) {
56+
std::stringstream ss(item.substr(keys_flag.length()));
57+
std::string key_item;
58+
while (std::getline(ss, key_item, ',')) {
59+
t->keys_array.push_back(ir::StringImm::make(key_item));
60+
}
5361
}
5462
}
5563

56-
if (t->device_name.length() > 0) {
57-
t->keys_array.push_back(ir::StringImm::make(t->device_name));
58-
}
5964
t->device_type = kDLCPU;
6065
t->thread_warp_size = 1;
6166
if (target_name == "c" || target_name == "llvm") {

topi/python/topi/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@
3434
from . import image
3535
from . import sparse
3636
from . import hls
37+
38+
# some short cut
39+
from .util import InvalidShapeError
40+
3741
# not import testing by default
3842
# because testing can have extra deps that are not necessary
3943
# we can import them from test cases explicitly

topi/python/topi/util.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@
66
import tvm
77
from . import tag
88

9+
class InvalidShapeError(ValueError):
10+
"""Invalid shape for a topi function. i.e. call winograd template for non-3x3 kernel)"""
11+
pass
12+
913
def traverse_inline(s, final_op, callback):
1014
"""Traverse computation graph and do auto inline
1115

vta/python/vta/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
# to maintain minimum dependency on the board
1919
if sys.argv[0] not in ("-c", "-m"):
2020
from . import top
21-
from .build_module import build_config, lower, build
2221
from . import graph
22+
23+
from .build_module import build_config, lower, build, vta_autotvm_build_func
2324
from .ptr_alias import reinterpret

vta/python/vta/build_module.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,3 +102,39 @@ def build(*args, **kwargs):
102102
with build_config():
103103
return tvm.build(*args, **kwargs)
104104
return tvm.build(*args, **kwargs)
105+
106+
107+
def vta_autotvm_build_func(measure_input, tmp_dir, **kwargs):
108+
"""Custom build func for VTA. Used for autotvm"""
109+
110+
import time
111+
import os
112+
from random import getrandbits
113+
from tvm.autotvm.util import get_const_tuple
114+
from tvm.autotvm.measure.measure_methods import BuildResult, InstantiationError
115+
116+
tic = time.time()
117+
try:
118+
filename = os.path.join(tmp_dir, "tmp_func_%0x.tar" % getrandbits(64))
119+
target, task, config = measure_input
120+
121+
with target:
122+
s, args = task.instantiate(config)
123+
if not config.valid():
124+
raise InstantiationError(config.errors)
125+
126+
func = build(s, args, target_host=task.target_host)
127+
func2 = build(s, args)
128+
129+
arg_info = tuple((get_const_tuple(x.shape), x.dtype) for x in args)
130+
func.export_library(filename)
131+
132+
# check by local simulator
133+
ctx = tvm.context(str(target))
134+
args = [tvm.nd.empty(x[0], dtype=x[1], ctx=ctx) for x in arg_info]
135+
func2(*args)
136+
137+
except Exception as e: # pylint: disable=broad-except
138+
return BuildResult(None, None, e, time.time() - tic)
139+
return BuildResult(filename, arg_info, None, time.time() - tic)
140+

vta/python/vta/environment.py

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -227,34 +227,24 @@ def gemm(self):
227227
"""GEMM intrinsic"""
228228
return self.dev.gemm
229229

230-
# TODO get rid of it
231230
@property
232-
def target_host(self):
233-
"""The target host"""
234-
return "llvm " + self.llvm_triple
231+
def target(self):
232+
return tvm.target.vta(model=self.TARGET)
235233

236234
@property
237-
def target_vta_cpu(self):
235+
def target_host(self):
238236
"""The target host"""
239237
if self.TARGET == "pynq":
240-
return "llvm -device=arm_cpu -model=pynq {}".format(self.llvm_triple)
238+
return "llvm -target=armv7-none-linux-gnueabihf"
241239
elif self.TARGET == "ultra96":
242-
return "llvm -device=arm_cpu -model=ultra96 {}".format(self.llvm_triple)
240+
return "llvm -target=aarch64-linux-gnu"
243241
elif self.TARGET == "sim":
244242
return "llvm"
245243
raise ValueError("Unknown target %s" % self.TARGET)
246244

247245
@property
248-
def llvm_triple(self):
249-
"""The llvm flags for the target platform"""
250-
if self.TARGET == "pynq":
251-
return "-target=armv7-none-linux-gnueabihf"
252-
elif self.TARGET == "ultra96":
253-
return "-target=aarch64-linux-gnu"
254-
elif self.TARGET == "sim":
255-
return ""
256-
else:
257-
raise ValueError("Unknown target %s" % self.TARGET)
246+
def target_vta_cpu(self):
247+
return tvm.target.arm_cpu(model=self.TARGET)
258248

259249

260250
def get_env():

vta/python/vta/top/__init__.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,4 @@
11
"""TVM TOPI connector, eventually most of these should go to TVM repo"""
22

33
from . import vta_conv2d
4-
from . import arm_conv2d
5-
from . import testing
6-
7-
from .bitpack import bitpack
8-
from .vta_dense import packed_dense, schedule_packed_dense
9-
from .vta_conv2d import packed_conv2d, schedule_packed_conv2d
10-
from .vta_group_conv2d import packed_group_conv2d, schedule_packed_group_conv2d
11-
from .vta_conv2d_transpose import packed_conv2d_transpose, schedule_packed_conv2d_transpose
4+
from . import op

0 commit comments

Comments (0)