@@ -54,7 +54,8 @@ def _lower(func,
5454 return grc .codegen (mod ["main" ])
5555
5656
57- def extract_from_program (func , params , ops , target , target_host = None ):
57+ def extract_from_program (func , params , ops , target , target_host = None ,
58+ template_keys = None ):
5859 """ Extract tuning tasks from a relay program.
5960
6061 This function is the single program version of extract_from_multiple_program.
@@ -71,16 +72,21 @@ def extract_from_program(func, params, ops, target, target_host=None):
7172 The compilation target
7273 target_host: tvm.target.Target
7374 The host compilation target
75+ template_keys: dict of topi op to str
76+ The tuning template keys map for schedules, default to None.
77+ Example: {topi.nn.conv2d: 'direct'}
7478
7579 Returns
7680 -------
7781 task: Array of autotvm.task.Task
7882 collected tasks
7983 """
80- return extract_from_multiple_program ([func ], [params ], ops , target , target_host )
84+ return extract_from_multiple_program ([func ], [params ], ops , target , target_host ,
85+ template_keys = template_keys )
8186
8287
83- def extract_from_multiple_program (funcs , params , ops , target , target_host = None ):
88+ def extract_from_multiple_program (funcs , params , ops , target , target_host = None ,
89+ template_keys = None ):
8490 """ Extract tuning tasks from multiple relay programs.
8591
8692 This function collects tuning tasks by building a list of programs
@@ -98,6 +104,9 @@ def extract_from_multiple_program(funcs, params, ops, target, target_host=None):
98104 The compilation target
99105 target_host: tvm.target.Target
100106 The host compilation target
107+ template_keys: dict of topi op to str
108+ The tuning template keys map for schedules, default to None.
109+ Example: {topi.nn.conv2d: 'direct'}
101110
102111 Returns
103112 -------
@@ -146,15 +155,26 @@ def extract_from_multiple_program(funcs, params, ops, target, target_host=None):
146155
147156 logger .disabled = old_state
148157
158+ # convert *topi op to template key* map to *task name to template key* map
159+ task_name_to_keys = {}
160+ if template_keys is not None :
161+ for op in template_keys .keys ():
162+ if op in env .topi_to_task :
163+ task_name_to_keys [env .topi_to_task [op ]] = template_keys [op ]
164+ else :
165+ logger .warning ("Invalid template key, fallback to direct" )
166+ task_name_to_keys [env .topi_to_task [op ]] = 'direct'
167+
149168 # create tasks for target
150169 tasks = []
151170 for task_name , args in env .get_tasks ():
152171 try :
172+ key = task_name_to_keys [task_name ] if task_name in task_name_to_keys else 'direct'
153173 tsk = create (task_name , args ,
154174 target = target , target_host = target_host ,
155- template_key = 'direct' )
175+ template_key = key )
156176 tasks .append (tsk )
157177 except topi .InvalidShapeError :
158- print ( "[Warning] Invalid shape during AutoTVM task creation" )
178+ logger . warning ( " Invalid shape during AutoTVM task creation" )
159179
160180 return tasks
0 commit comments