Commit 1ce9417

merrymercy authored and tmoreau89 committed
update single op example (apache#17)
1 parent e467b5e commit 1ce9417

File tree

2 files changed: +222 -0 lines changed
Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
#!/bin/bash
PROJROOT="$( cd "$( dirname "${BASH_SOURCE[0]}" )/../../" && pwd )"

export PYTHONPATH=${PYTHONPATH}:${PROJROOT}/python:${PROJROOT}/vta/python
python3.6 -m vta.exec.rpc_server --tracker fleet:9190 --key ultra96
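
This board-side script puts the TVM and VTA Python packages on PYTHONPATH and starts a VTA RPC server that registers with an RPC tracker at fleet:9190 under the key ultra96 (the tracker hostname, port, and key are specific to this setup). A minimal host-side sanity check, sketched below on the assumption that the tracker is already running, can confirm that the board registered:

# Sketch only: verify the board-side RPC server shows up at the tracker.
from tvm import rpc

tracker = rpc.connect_tracker('fleet', 9190)   # tracker host/port used by the script above
print(tracker.text_summary())                  # should list a free server under key 'ultra96'
remote = tracker.request('ultra96')            # request a session for a quick smoke test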

vta/scripts/tune_conv.py

Lines changed: 217 additions & 0 deletions
@@ -0,0 +1,217 @@
"""Tuning a conv2d operator"""
import tvm
import sys
import logging
from tvm import autotvm
from tvm.contrib.util import get_lower_ir
import topi

import vta
import vta.testing
from vta.top.testing import my_clip

env = vta.get_env()

def vta_build_func(measure_input, tmp_dir, **kwargs):
    """Custom autotvm build function that compiles through vta.build so that
    VTA-specific lowering passes are applied."""
    import time
    import os
    from tvm.autotvm.measure.measure_methods import BuildResult
    from tvm.autotvm.task.space import InstantiationError  # used for the validity check below
    from random import getrandbits
    from tvm.autotvm.util import get_const_tuple
    tic = time.time()
    try:
        filename = os.path.join(tmp_dir, "tmp_func_%0x.tar" % getrandbits(64))
        target, task, config = measure_input

        with target:
            s, args = task.instantiate(config)
            if not config.valid():
                raise InstantiationError(config.errors)

            func = vta.build(s, args, target='ext_dev', target_host=task.target_host)

        arg_info = tuple((get_const_tuple(x.shape), x.dtype) for x in args)
        func.export_library(filename)
    except Exception as e:  # pylint: disable=broad-except
        return BuildResult(None, None, e, time.time() - tic)
    return BuildResult(filename, arg_info, None, time.time() - tic)


def schedule_packed_conv2d(cfg, outs,
                           skip_load_inp=False, skip_load_wgt=False, skip_load_acc=False,
                           skip_store_out=False, skip_alu=False, skip_gemm=False):
    """Schedule the packed conv2d."""
    assert len(outs) == 1
    output = outs[0]
    ewise_inputs = []
    ewise_ops = []
    conv2d_res = []
    assert output.op.input_tensors[0].dtype == "int32"

    def _traverse(op):
        if topi.tag.is_broadcast(op.tag):
            if not op.same_as(output.op):
                ewise_ops.append(op)
            for tensor in op.input_tensors:
                if isinstance(tensor.op, tvm.tensor.PlaceholderOp):
                    ewise_inputs.append((op, tensor))
                else:
                    _traverse(tensor.op)
        else:
            assert op.tag == "packed_conv2d"
            conv2d_res.append(op)

    _traverse(output.op)
    assert len(conv2d_res) == 1
    conv2d_stage = conv2d_res[0].output(0)
    s = tvm.create_schedule(output.op)

    ##### space definition begin #####
    b, co, h, w, bi, ci = s[conv2d_stage].op.axis
    ci, kh, kw, bci = s[conv2d_stage].op.reduce_axis
    cfg.define_split('tile_b', b, num_outputs=2)
    cfg.define_split('tile_h', h, num_outputs=2)
    cfg.define_split('tile_w', w, num_outputs=2)
    cfg.define_split('tile_ci', ci, num_outputs=2)
    cfg.define_split('tile_co', co, num_outputs=2)
    cfg.define_knob('oc_nthread', [1, 2])
    cfg.define_knob('h_nthread', [1, 2])
    ###### space definition end ######

    data, kernel = conv2d_stage.op.input_tensors
    if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
        temp = data.op.input_tensors[0]
        pad_data = data
        data = temp
    else:
        pad_data = None

    mock = env.mock
    load_inp = mock.dma_copy if skip_load_inp else env.dma_copy
    load_wgt = mock.dma_copy if skip_load_wgt else env.dma_copy
    load_acc = mock.dma_copy if skip_load_acc else env.dma_copy
    store_out = mock.dma_copy if skip_store_out else env.dma_copy
    alu = mock.alu if skip_alu else env.alu
    gemm = mock.gemm if skip_gemm else env.gemm

    # schedule
    oshape = topi.util.get_const_tuple(output.shape)

    # setup pad
    if pad_data is not None:
        cdata = pad_data
        s[pad_data].set_scope(env.inp_scope)
    else:
        cdata = s.cache_read(data, env.inp_scope, [conv2d_stage])
    ckernel = s.cache_read(kernel, env.wgt_scope, [conv2d_stage])
    s[conv2d_stage].set_scope(env.acc_scope)

    # cache read input
    cache_read_ewise = []
    for consumer, tensor in ewise_inputs:
        cache_read_ewise.append(
            s.cache_read(tensor, env.acc_scope, [consumer]))

    # set ewise scope
    for op in ewise_ops:
        s[op].set_scope(env.acc_scope)
        s[op].pragma(s[op].op.axis[0], alu)

    # tile
    x_bo, x_co, x_i, x_j, x_bi, x_ci = s[output].op.axis
    x_co0, x_co1 = cfg['tile_co'].apply(s, output, x_co)
    x_i0, x_i1 = cfg['tile_h'].apply(s, output, x_i)
    x_j0, x_j1 = cfg['tile_w'].apply(s, output, x_j)
    s[output].reorder(x_bo, x_i0, x_co0, x_j0, x_co1, x_i1, x_j1, x_bi, x_ci)
    store_pt = x_j0

    # set all compute scopes
    s[conv2d_stage].compute_at(s[output], store_pt)
    for op in ewise_ops:
        s[op].compute_at(s[output], store_pt)

    for tensor in cache_read_ewise:
        s[tensor].compute_at(s[output], store_pt)
        s[tensor].pragma(s[tensor].op.axis[0], load_acc)

    # virtual threading along output channel axes
    if cfg['oc_nthread'].val > 1:
        _, v_t = s[output].split(x_co0, factor=cfg['oc_nthread'].val)
        s[output].reorder(v_t, x_bo)
        s[output].bind(v_t, tvm.thread_axis("cthread"))

    # virtual threading along spatial rows
    if cfg['h_nthread'].val > 1:
        _, v_t = s[output].split(x_i0, factor=cfg['h_nthread'].val)
        s[output].reorder(v_t, x_bo)
        s[output].bind(v_t, tvm.thread_axis("cthread"))

    x_bo, x_co, x_i, x_j, x_bi, x_ci = s[conv2d_stage].op.axis
    k_o, d_i, d_j, k_i = s[conv2d_stage].op.reduce_axis
    s[conv2d_stage].reorder(x_bo, k_o, x_j, d_j, d_i, x_co, x_i, x_bi, x_ci, k_i)

    k_o, _ = cfg['tile_ci'].apply(s, conv2d_stage, k_o)
    s[cdata].compute_at(s[conv2d_stage], k_o)
    s[ckernel].compute_at(s[conv2d_stage], k_o)

    # Use VTA instructions
    s[cdata].pragma(s[cdata].op.axis[0], load_inp)
    s[ckernel].pragma(s[ckernel].op.axis[0], load_wgt)
    s[conv2d_stage].tensorize(x_bi, gemm)
    s[output].pragma(x_co1, store_out)
    return s


@autotvm.template
def conv2d(N, CI, H, W, CO, KH, KW, strides, padding, in_dtype, out_dtype):
    data_shape = (N//env.BATCH, CI//env.BLOCK_IN, H, W, env.BATCH, env.BLOCK_IN)
    kernel_shape = (CO//env.BLOCK_OUT, CI//env.BLOCK_IN, KH, KW, env.BLOCK_OUT, env.BLOCK_IN)
    bias_shape = (N//env.BATCH, CO//env.BLOCK_OUT, 1, 1, env.BATCH, env.BLOCK_OUT)

    OH = (H + 2 * padding[0] - KH) // strides[0] + 1
    OW = (W + 2 * padding[1] - KW) // strides[1] + 1

    data = tvm.placeholder(data_shape, name="data", dtype=env.inp_dtype)
    bias = tvm.placeholder(bias_shape, name="bias", dtype=env.acc_dtype)

    w_pack_factor = 1 << (3 - env.LOG_WGT_WIDTH)
    kernel_shape_pack = kernel_shape[:-1] + (kernel_shape[-1] // w_pack_factor,)
    kernel_arg = tvm.placeholder(kernel_shape_pack, dtype="int8", name="kernel_arg")
    kernel = vta.reinterpret(kernel_arg, kernel_shape, dtype=env.wgt_dtype)

    res_conv = vta.top.packed_conv2d(data, kernel, padding=padding, strides=strides)
    res = topi.right_shift(res_conv, 8)
    res = topi.add(res, bias)
    res = my_clip(res, 0, 127)
    res = topi.cast(res, "int8")

    cfg = autotvm.get_config()
    s = schedule_packed_conv2d(cfg, [res])

    cfg.add_flop(2 * N * CI * OH * OW * CO * KH * KW)
    return s, [data, kernel_arg, bias, res]
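
For concreteness: with VTA's default hardware parameters (BATCH = 1, BLOCK_IN = BLOCK_OUT = 16, LOG_WGT_WIDTH = 3, taken from the usual vta_config.json defaults and only an assumption here, not fixed by this commit), the packed shapes this template builds for the workload used in __main__ below work out as in the sketch:

# Worked example under the assumed default VTA parameters:
#   data_shape   = (1//1, 64//16, 56, 56, 1, 16)   -> (1, 4, 56, 56, 1, 16)
#   kernel_shape = (64//16, 64//16, 3, 3, 16, 16)  -> (4, 4, 3, 3, 16, 16)
#   bias_shape   = (1//1, 64//16, 1, 1, 1, 16)     -> (1, 4, 1, 1, 1, 16)
#   w_pack_factor = 1 << (3 - 3) = 1, so kernel_arg has the same shape as kernel.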


if __name__ == '__main__':
    N, CI, H, W, CO, KH, KW, strides, padding, in_dtype, out_dtype = \
        1, 64, 56, 56, 64, 3, 3, (1, 1), (1, 1), 'int8', 'int32'

    task = autotvm.task.create(conv2d, args=(N, CI, H, W, CO, KH, KW, strides, padding, in_dtype, out_dtype),
                               target='ext_dev', target_host=env.target_host)
    print(task.config_space)

    # logging config (for printing tuning log to the screen)
    logging.getLogger('autotvm').setLevel(logging.DEBUG)
    logging.getLogger('autotvm').addHandler(logging.StreamHandler(sys.stdout))

    measure_option = autotvm.measure_option(
        builder=autotvm.LocalBuilder(build_func=vta_build_func),
        runner=autotvm.RPCRunner(
            'ultra96', 'fleet', 9190))

    tuner = autotvm.tuner.RandomTuner(task)
    tuner.tune(n_trial=len(task.config_space),
               measure_option=measure_option,
               callbacks=[autotvm.callback.log_to_file('conv2d.log')])

    print(tuner.best_config)
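
After tuning, the best schedule configuration found by the tuner is recorded in conv2d.log. A minimal compile sketch using that log, following the usual autotvm workflow (apply_history_best and tvm.target.create are assumptions about the surrounding TVM API, not something this commit adds), might look like:

# Sketch only: compile the tuned conv2d with the best configuration from the log,
# reusing the workload parameters defined in __main__ above.
with autotvm.apply_history_best('conv2d.log'):
    with tvm.target.create('ext_dev'):
        s, arg_bufs = conv2d(N, CI, H, W, CO, KH, KW, strides, padding, in_dtype, out_dtype)
        mod = vta.build(s, arg_bufs, target='ext_dev', target_host=env.target_host)
        mod.export_library('conv2d_best.tar')  # hypothetical output filename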
