Skip to content

Commit a155c1f

Browse files
minminsunmerrymercy
authored andcommitted
Ansor Relay Integration (without layout rewrite) (apache#22)
* relay integration
1 parent 3a24e49 commit a155c1f

File tree

12 files changed

+1579
-2
lines changed

12 files changed

+1579
-2
lines changed

python/tvm/ansor/__init__.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,20 @@
2828
from . import task_scheduler
2929

3030
# Shortcut
31-
from .compute_dag import ComputeDAG
31+
from .compute_dag import ComputeDAG, LayoutRewriteLevel, gen_schedule
3232
from .auto_schedule import SearchTask, MetaTileRewritePolicy, TuneOption, HardwareParams, \
3333
PreLoadMeasuredStates, PreAddCustomRule
3434
from .auto_schedule import auto_schedule
3535
from .measure import MeasureInput, LocalBuilder, LocalRunner, RPCRunner, LocalRPCMeasureContext
3636
from .cost_model import RandomModel
3737
from .cost_model.xgb_model import XGBModel
38-
from .serialization import LogToFile, LogReader, best_measure_pair_in_file, write_measure_records_to_file
38+
from .serialization import LogToFile, LogReader, best_measure_pair_in_file, \
39+
load_from_file, write_measure_records_to_file
3940
from .workload_registry import register_auto_scheduler_workload_func, \
4041
workload_key_to_dag, make_workload_key_func
4142
from .task_scheduler import TaskScheduler, SimpleTaskScheduler
43+
from .dispatcher import DispatchContext, ApplyConfig, ApplyHistoryBest as apply_history_best, \
44+
FallbackContext, clear_fallback_cache, ApplyGraphBest, BlockingEmptyContext
45+
from .topi_integration import register_topi_schedule, TaskExtractEnv
46+
from .relay_integration import extract_from_program, extract_from_multiple_program, \
47+
finish_layout_rewrite

python/tvm/ansor/compute_dag.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import tvm._ffi
2121
from tvm.runtime import Object
22+
from tvm import te
2223
from .loop_state import State
2324
from . import _ffi_api
2425

@@ -88,3 +89,13 @@ def infer_bound_from_state(self, state):
8889
state : StateObject
8990
"""
9091
return _ffi_api.ComputeDAGInferBoundFromState(self, state)
92+
93+
def gen_schedule(state, bufs):
94+
if not state or not state.complete:
95+
return te.create_schedule([x.op for x in bufs])
96+
else:
97+
dag = ComputeDAG(bufs)
98+
# only update compute body, layout_rewrite_level = LayoutRewriteLevel.COMPUTE_REWRITE,
99+
# since kernel layout has already been rewritten in relay pass
100+
schedule, _ = dag.apply_steps_from_state(state, layout_rewrite_level=LayoutRewriteLevel.COMPUTE_REWRITE)
101+
return schedule

0 commit comments

Comments
 (0)