55"""
66import warnings
77import logging
8+ import sys
89
910
1011from ... import tensor , placeholder , create_schedule , target as _target
@@ -49,9 +50,9 @@ def deserialize_args(args):
 # Task extractor for nnvm graph
 class TaskExtractEnv:
     """Global environment for extracting tuning tasks from nnvm graph"""
-    current = None
+    registered = False
 
-    def __init__(self):
+    def __init__(self, wanted_symbols):
         import topi
         import nnvm
 
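The hunk above swaps the `current` singleton handle for a class-level `registered` flag, so `TaskExtractEnv` can be instantiated once per extraction while the process-wide registration of TOPI task templates still happens only once (see the guard added in `_register_topi_task` further down). A minimal sketch of that register-once pattern, with illustrative names:

    class Env:
        registered = False               # class-level flag, shared by all instances

        def __init__(self):
            if Env.registered:           # later instances skip re-registration
                return
            Env.registered = True
            print("registering global task templates")

    Env()   # prints once
    Env()   # silent: templates are already registered
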
@@ -83,46 +84,62 @@ def __init__(self):
             topi.nn.dense: [topi.generic.schedule_dense],
         }
 
-        self._register_tracing()
+        # support reflection for tracing
+        self.func_to_reflection = {
+            topi.nn.conv2d: lambda x: setattr(topi.nn, 'conv2d', x),
+            topi.nn.depthwise_conv2d_nchw: lambda x: setattr(topi.nn, 'depthwise_conv2d_nchw', x),
+            topi.nn.conv2d_transpose_nchw: lambda x: setattr(topi.nn, 'conv2d_transpose_nchw', x),
+            topi.nn.dense: lambda x: setattr(topi.nn, 'dense', x),
+        }
+
+
+        self.wanted_topi_funcs = []
+        for sym_name in wanted_symbols:
+            if sym_name in self.symbol2topi:
+                self.wanted_topi_funcs.extend(self.symbol2topi[sym_name])
+            else:
+                warnings.warn("Symbol %s is not tunable, ignored" % sym_name)
+
         self._register_topi_task()
         self.task_collection = []
-        self.wanted_topi_funcs = list(self.topi_to_task.keys())
+        self.modified_funcs = []
 
-    def _register_tracing(self):
-        """Register tracing function to track the topi function call"""
-        # register topi compute for "tracing" target
-        for topi_compute in self.topi_to_task:
+    def __enter__(self):
+        self.task_collection = []
+        self.modified_funcs = []
+
+        for topi_compute in self.wanted_topi_funcs:
             def _local_scope(compute_func):
                 """start a scope to hold the local function in for loop"""
 
-                @compute_func.register("tracing",)
-                def _tracing_topi_compute(*args, **kwargs):
-                    assert not kwargs, "Do not support extracting tuning tasks when" \
-                        "kwargs is used in TOPI function call." \
+                def _tracing_wrapper(*args, **kwargs):
+                    assert not kwargs, "Do not support extracting tuning tasks when " \
+                        "kwargs is used in TOPI function call. " \
                         "Please modify it to use only positional args."
 
-                    if compute_func in self.wanted_topi_funcs:  # record this call
-                        key = (self.topi_to_task[compute_func], serialize_args(args))
-                        if key not in self.task_collection:
-                            self.task_collection.append(key)
+                    key = (self.topi_to_task[compute_func], serialize_args(args))
+                    if key not in self.task_collection:
+                        self.task_collection.append(key)
+
+                    return compute_func(*args, **kwargs)
+
+                self.func_to_reflection[topi_compute](_tracing_wrapper)
+                self.modified_funcs.append(topi_compute)
 
-                    return compute_func.fdefault(*args)
             _local_scope(topi_compute)
 
 
-        # register topi schedule for "tracing" target
-        for topi_compute in self.topi_to_task:
-            for topi_schedule in self.topi_to_schedule[topi_compute]:
-                def _local_scope_(schedule_func):
-                    """start a scope to hold the local function in for loop"""
+        return self
 
-                    @schedule_func.register("tracing",)
-                    def _tracing_topi_compute(outs):
-                        outs = [outs] if isinstance(outs, tensor.Tensor) else outs
-                        return create_schedule([x.op for x in outs])
-                _local_scope_(topi_schedule)
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        # revert modification
+        for func in self.modified_funcs:
+            self.func_to_reflection[func](func)
 
     def _register_topi_task(self):
         """register tuning wrapper for topi function"""
+        if TaskExtractEnv.registered:
+            return
+        TaskExtractEnv.registered = True
         import topi
 
         # Tuning wrapper for topi functions
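The `__enter__`/`__exit__` pair above traces TOPI calls by monkey-patching: each wanted compute function is swapped for a wrapper that records its arguments and then delegates to the original, and `__exit__` puts the originals back through `func_to_reflection`. The per-iteration `_local_scope` closure exists to pin `compute_func` to the current loop value, sidestepping Python's late binding of loop variables. A self-contained sketch of the same wrap-and-restore pattern, using a stand-in namespace instead of `topi.nn` (all names are illustrative):

    import types

    fake_topi = types.SimpleNamespace(conv2d=lambda *args: ("conv2d", args))

    class Tracer:
        def __enter__(self):
            self.calls = []
            self.orig = fake_topi.conv2d
            orig = self.orig
            def wrapper(*args):
                self.calls.append(args)    # record the call ...
                return orig(*args)         # ... then run the real function
            fake_topi.conv2d = wrapper     # setattr-style patch, like func_to_reflection
            return self

        def __exit__(self, exc_type, exc_val, exc_tb):
            fake_topi.conv2d = self.orig   # restore even if the body raised

    with Tracer() as t:
        fake_topi.conv2d("data", "kernel")
    print(t.calls)                         # [('data', 'kernel')]
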
@@ -175,17 +192,6 @@ def _topi_nn_dense(*args, **kwargs):
                 return s, [data, weight, bias, C]
             return s, [data, weight, C]
 
-    def reset(self, wanted_topi_funcs):
-        """Reset task collections
-
-        Parameters
-        ----------
-        wanted_topi_funcs: List of function
-            The topi function to be extracted
-        """
-        self.task_collection = []
-        self.wanted_topi_funcs = wanted_topi_funcs
-
     def get_tasks(self):
         """Get collected tasks
 
@@ -196,25 +202,11 @@ def get_tasks(self):
196202 """
197203 return self .task_collection
198204
199- @staticmethod
200- def get ():
201- """Get the single instance of TaskExtractEnv
202-
203- Returns
204- -------
205- env: TaskExtractEnv
206- The single instance of TaskExtractEnv
207- """
208- if not TaskExtractEnv .current :
209- TaskExtractEnv .current = TaskExtractEnv ()
210- return TaskExtractEnv .current
211-
212205
213206def extract_from_graph (graph , shape , dtype , target , symbols , target_host = None ):
214207 """ Extract tuning tasks from a nnvm graph.
215208
216- This function collects tuning tasks by building the graph
217- with a "tracing" target and tracing all the calls to topi.
+    This function collects tuning tasks by building the graph and tracing all the calls to topi.
 
     Parameters
     ----------
@@ -237,97 +229,34 @@ def extract_from_graph(graph, shape, dtype, target, symbols, target_host=None):
         collected tasks
     """
     import nnvm.compiler
+    import topi
 
-    env = TaskExtractEnv.get()
+    env = TaskExtractEnv(symbols)
 
-    topi_funcs = []
-    for sym_name in symbols:
-        if sym_name in env.symbol2topi:
-            topi_funcs.extend(env.symbol2topi[sym_name])
-        else:
-            warnings.warn("Symbol %s is not tunable, ignored" % sym_name)
+    with env:
+        # disable logger temporarily
+        old_state = logger.disabled
+        logger.disabled = True
 
-    # run compiler to collect all TOPI calls during compilation
-    env.reset(topi_funcs)
+        # run compiler to collect all TOPI calls during compilation
+        nnvm.compiler.engine.clear_cache()
+        nnvm.compiler.build(graph, target=target, shape=shape, dtype=dtype)
+        nnvm.compiler.engine.clear_cache()
 
-    # disable logger temporarily
-    old_state = logger.disabled
-    logger.disabled = True
-
-    # use a "tracing" target to do a fake compile for collecting topi calls
-    tracing_target = _target.create("llvm -device=tracing")
-    nnvm.compiler.engine.clear_cache()
-    nnvm.compiler.build(graph, target=tracing_target, shape=shape, dtype=dtype)
-
-    logger.disabled = old_state
+        logger.disabled = old_state
 
     # create tasks for target
     tasks = []
     for task_name, args in env.get_tasks():
-        tasks.append(create(task_name, args,
-                            target=target, target_host=target_host,
-                            template_key='direct'))
+        try:
+            tsk = create(task_name, args,
+                         target=target, target_host=target_host,
+                         template_key='direct')
+            tasks.append(tsk)
+        except topi.InvalidShapeError:
+            print("shape error")
 
     return tasks
 
-
-def extract_from_multiple_graph(graphs, shapes, dtypes, target, symbols, target_host=None):
-    """ Extract tuning tasks from multiple nnvm graphs.
-
-    This function is the multiple graph version of extract_from_graph
-
-    Parameters
-    ----------
-    graphs : List of Graph
-        The list of graphs to tune
-    shapes : List of dict of str to tuple
-        The input shape to the graph
-    dtypes : List of str or dict of str to str
-        The input types to the graph
-    target: tvm.target.Target
-        The compilation target
-    symbols : Array of nnvm.symbol
-        Array of nnvm symbols want to be tuned
-    target_host: tvm.target.Target
-        The host compilation target
-
-    Returns
-    -------
-    task: Array of autotvm.task.Task
-        collected tasks
-    """
-    import nnvm.compiler
-
-    env = TaskExtractEnv.get()
-
-    topi_funcs = []
-    for sym_name in symbols:
-        if sym_name in env.symbol2topi:
-            topi_funcs.extend(env.symbol2topi[sym_name])
-        else:
-            warnings.warn("Symbol %s is not tunable, ignored" % sym_name)
-
-    # run compiler to collect all TOPI calls during compilation
-    env.reset(topi_funcs)
-
-    # disable logger temporarily
-    old_state = logger.disabled
-    logger.disabled = True
-
-    # use a "tracing" target to do a fake compile for collecting topi calls
-    tracing_target = _target.create("llvm -device=tracing")
-
-    nnvm.compiler.engine.clear_cache()
-    for graph, shape, dtype in zip(graphs, shapes, dtypes):
-        nnvm.compiler.build(graph, target=tracing_target, shape=shape, dtype=dtype)
-
-    logger.disabled = old_state
-
-    # create tasks for target
-    tasks = []
-    for task_name, args in env.get_tasks():
-        tasks.append(create(task_name, args,
-                            target=target, target_host=target_host,
-                            template_key='direct'))
-
-    return tasks
+def extract_from_multiple_graph(graph, shape, dtype, target, symbols, target_host=None):
+    pass
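With this change, task extraction builds the graph against the real target and collects tasks as a side effect of the patched TOPI calls, so no fake "tracing" target is needed. A hypothetical end-to-end use, patterned on the AutoTVM tutorials of this era (the workload, shapes, and target below are illustrative, not part of the diff):

    import nnvm
    import nnvm.testing
    from tvm import autotvm

    # illustrative workload; any nnvm graph works
    net, params = nnvm.testing.resnet.get_workload(batch_size=1)

    tasks = autotvm.task.extract_from_graph(
        net, shape={"data": (1, 3, 224, 224)}, dtype="float32",
        target="llvm", symbols=(nnvm.sym.conv2d,))

    for tsk in tasks:
        print(tsk)   # one tunable task per unique (template, args) pair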