55"""
66import warnings
77import logging
8+ import sys
89
910
1011from ... import target as _target
1819def extract_from_graph (graph , shape , dtype , target , symbols , target_host = None ):
1920 """ Extract tuning tasks from a nnvm graph.
2021
21- This function collects tuning tasks by building the graph
22- with a "tracing" target and tracing all the calls to topi.
22+ This function collects tuning tasks by building the graph and trace all the calls to topi.
2323
2424 Parameters
2525 ----------
@@ -45,7 +45,7 @@ def extract_from_graph(graph, shape, dtype, target, symbols, target_host=None):
4545 import nnvm
4646 import topi
4747
48- env = TaskExtractEnv . get ( )
48+ env = TaskExtractEnv ( symbols )
4949
5050 #NOTE: To add more symbols, you only need to change the following lists
5151 #nnvm symbol -> topi compute
@@ -63,26 +63,23 @@ def extract_from_graph(graph, shape, dtype, target, symbols, target_host=None):
6363 else :
6464 warnings .warn ("Symbol %s is not tunable, ignored" % sym_name )
6565
66- # run compiler to collect all TOPI calls during compilation
67- env .reset (topi_funcs )
66+ # run compiler to collect all TOPI calls during compilation
67+ nnvm .compiler .engine .clear_cache ()
68+ nnvm .compiler .build (graph , target = target , shape = shape , dtype = dtype )
69+ nnvm .compiler .engine .clear_cache ()
6870
69- # disable logger temporarily
70- old_state = logger .disabled
71- logger .disabled = True
72-
73- # use a "tracing" target to do a fake compile for collecting topi calls
74- tracing_target = _target .create ("llvm -device=tracing" )
75- nnvm .compiler .engine .clear_cache ()
76- nnvm .compiler .build (graph , target = tracing_target , shape = shape , dtype = dtype )
77-
78- logger .disabled = old_state
71+ logger .disabled = old_state
7972
8073 # create tasks for target
8174 tasks = []
8275 for task_name , args in env .get_tasks ():
83- tasks .append (create (task_name , args ,
84- target = target , target_host = target_host ,
85- template_key = 'direct' ))
76+ try :
77+ tsk = create (task_name , args ,
78+ target = target , target_host = target_host ,
79+ template_key = 'direct' )
80+ tasks .append (tsk )
81+ except topi .InvalidShapeError :
82+ print ("shape error" )
8683
8784 return tasks
8885
0 commit comments