Skip to content

Commit ccde31f

Browse files
zhenhuaw-mevinx13
authored andcommitted
AutoTVM: selecting tuning templates when extracting task (#4338)
* AutoTVM: selecting tuning templates when extracting task Make the procedure of trying new templates easier. Test: tests/python/relay/test_autotvm_task_extraction.py * Use dict to match key for topi ops * fix lint issue * be more pythonic :)
1 parent 0a9f7e9 commit ccde31f

File tree

2 files changed

+71
-5
lines changed

2 files changed

+71
-5
lines changed

python/tvm/autotvm/task/relay_integration.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@ def _lower(func,
5454
return grc.codegen(mod["main"])
5555

5656

57-
def extract_from_program(func, params, ops, target, target_host=None):
57+
def extract_from_program(func, params, ops, target, target_host=None,
58+
template_keys=None):
5859
""" Extract tuning tasks from a relay program.
5960
6061
This function is the single program version of extract_from_multiple_program.
@@ -71,16 +72,21 @@ def extract_from_program(func, params, ops, target, target_host=None):
7172
The compilation target
7273
target_host: tvm.target.Target
7374
The host compilation target
75+
template_keys: dict of topi op to str
76+
The tuning template keys map for schedules, default to None.
77+
Example: {topi.nn.conv2d: 'direct'}
7478
7579
Returns
7680
-------
7781
task: Array of autotvm.task.Task
7882
collected tasks
7983
"""
80-
return extract_from_multiple_program([func], [params], ops, target, target_host)
84+
return extract_from_multiple_program([func], [params], ops, target, target_host,
85+
template_keys=template_keys)
8186

8287

83-
def extract_from_multiple_program(funcs, params, ops, target, target_host=None):
88+
def extract_from_multiple_program(funcs, params, ops, target, target_host=None,
89+
template_keys=None):
8490
""" Extract tuning tasks from multiple relay programs.
8591
8692
This function collects tuning tasks by building a list of programs
@@ -98,6 +104,9 @@ def extract_from_multiple_program(funcs, params, ops, target, target_host=None):
98104
The compilation target
99105
target_host: tvm.target.Target
100106
The host compilation target
107+
template_keys: dict of topi op to str
108+
The tuning template keys map for schedules, default to None.
109+
Example: {topi.nn.conv2d: 'direct'}
101110
102111
Returns
103112
-------
@@ -146,15 +155,26 @@ def extract_from_multiple_program(funcs, params, ops, target, target_host=None):
146155

147156
logger.disabled = old_state
148157

158+
# convert *topi op to template key* map to *task name to template key* map
159+
task_name_to_keys = {}
160+
if template_keys is not None:
161+
for op in template_keys.keys():
162+
if op in env.topi_to_task:
163+
task_name_to_keys[env.topi_to_task[op]] = template_keys[op]
164+
else:
165+
logger.warning("Invalid template key, fallback to direct")
166+
task_name_to_keys[env.topi_to_task[op]] = 'direct'
167+
149168
# create tasks for target
150169
tasks = []
151170
for task_name, args in env.get_tasks():
152171
try:
172+
key = task_name_to_keys[task_name] if task_name in task_name_to_keys else 'direct'
153173
tsk = create(task_name, args,
154174
target=target, target_host=target_host,
155-
template_key='direct')
175+
template_key=key)
156176
tasks.append(tsk)
157177
except topi.InvalidShapeError:
158-
print("[Warning] Invalid shape during AutoTVM task creation")
178+
logger.warning("Invalid shape during AutoTVM task creation")
159179

160180
return tasks

tests/python/relay/test_autotvm_task_extraction.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,5 +79,51 @@ def test_task_extraction():
7979
ops=(relay.op.nn.conv2d,))
8080
assert len(tasks) == 31
8181

82+
def test_template_key_provided():
83+
"""test task extraction using non-'direct' template_key"""
84+
target = 'llvm'
85+
86+
import topi
87+
template_keys = {
88+
# topi.nn.conv2d - is left blank to test fallback logic
89+
topi.nn.dense: 'direct_nopack',
90+
topi.nn.depthwise_conv2d_nchw: 'direct',
91+
}
92+
93+
mod, params, _ = get_network('mobilenet', batch_size=1)
94+
tasks = autotvm.task.extract_from_program(mod['main'], target=target,
95+
params=params,
96+
ops=(relay.op.nn.conv2d, relay.op.nn.dense),
97+
template_keys=template_keys)
98+
for task in tasks:
99+
if 'dense' in task.name:
100+
assert task.config_space.template_key == 'direct_nopack'
101+
else:
102+
assert task.config_space.template_key == 'direct'
103+
104+
def test_template_key_empty():
105+
"""test task extraction using empty template_key"""
106+
target = 'llvm'
107+
mod, params, _ = get_network('mobilenet', batch_size=1)
108+
tasks = autotvm.task.extract_from_program(mod['main'], target=target,
109+
params=params,
110+
ops=(relay.op.nn.conv2d, relay.op.nn.dense),
111+
template_keys=None)
112+
for task in tasks:
113+
assert task.config_space.template_key == 'direct'
114+
115+
def test_template_key_default():
116+
"""test task extraction without template_key"""
117+
target = 'llvm'
118+
mod, params, _ = get_network('mobilenet', batch_size=1)
119+
tasks = autotvm.task.extract_from_program(mod['main'], target=target,
120+
params=params,
121+
ops=(relay.op.nn.conv2d, relay.op.nn.dense))
122+
for task in tasks:
123+
assert task.config_space.template_key == 'direct'
124+
82125
if __name__ == '__main__':
83126
test_task_extraction()
127+
test_template_key_provided()
128+
test_template_key_empty()
129+
test_template_key_default()

0 commit comments

Comments
 (0)