Commit 558ba99

[MetaSchedule] Tuning Script Upgrade (#11797)
* Support uint8.
* Modify tuning functions.
* Follow legacy setting, use int32 for uint8.
* Add vm support.
* Fix vm usage.
* Use vm in rpc run module.
* Fix lint & stuff.
* Fix backend.
* Fix ftimer.
* Fix lint.
* Limit backend choice.
* Add try catch.
* Display name in rpc try catch.
* Support ahb from tune_relay.
* Modify scripts.
* Fix typo.
* Minor fix.
* Fix try catch & func name.
* Fix utils.
* Move utils to tune_utils.
* Fix tune_utils.
1 parent 898946f commit 558ba99
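The headline change is a required `--backend` flag that threads through both compilation and measurement: `graph` keeps the old graph-executor path, while `vm` compiles and runs through the Relay VM. A minimal sketch of the dispatch pattern the updated script adopts, grounded in the tune_onnx.py diff below; the `build_for_backend` wrapper itself is illustrative and not part of the commit:

```python
import tvm
from tvm import relay


def build_for_backend(mod, target, params, backend):
    """Illustrative wrapper: pick the compile entry point by backend name."""
    # Same dict dispatch as in the diff below: relay.build yields a
    # graph-executor module, relay.vm.compile yields a VM executable.
    relay_build = {"graph": relay.build, "vm": relay.vm.compile}[backend]
    with tvm.transform.PassContext(opt_level=3):
        return relay_build(mod, target=target, params=params)
```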

File tree

11 files changed: +448 -363 lines changed
python/tvm/auto_scheduler/testing/tune_onnx.py

Lines changed: 59 additions & 91 deletions
@@ -15,18 +15,18 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=missing-docstring
+from distutils.util import strtobool
 import argparse
 import json
 import os
-
-from distutils.util import strtobool
-import numpy as np  # type: ignore
 import onnx  # type: ignore
+
 import tvm
 from tvm import auto_scheduler
 from tvm import meta_schedule as ms
 from tvm import relay
 from tvm.meta_schedule.testing.custom_builder_runner import run_module_via_rpc
+from tvm.meta_schedule.testing.tune_utils import generate_input_data, create_timer
 from tvm.meta_schedule.utils import cpu_count
 from tvm.relay.frontend import from_onnx
 from tvm.support import describe
@@ -96,17 +96,23 @@ def _parse_args():
         default=100,
     )
     args.add_argument(
-        "--cpu-flush",
+        "--adaptive-training",
         type=lambda x: bool(strtobool(x)),
-        required=True,
         help="example: True / False",
+        default=True,
     )
     args.add_argument(
-        "--adaptive-training",
+        "--cpu-flush",
         type=lambda x: bool(strtobool(x)),
-        required=False,
         help="example: True / False",
-        default=True,
+        required=True,
+    )
+    args.add_argument(
+        "--backend",
+        type=str,
+        choices=["graph", "vm"],
+        help="example: graph / vm",
+        required=True,
     )
     parsed = args.parse_args()
     parsed.target = tvm.target.Target(parsed.target)
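The two boolean flags above rely on `distutils.util.strtobool`, so they take explicit textual values rather than acting as store-true switches, and `--backend` is constrained by `choices`. A quick standalone illustration of the parsing convention (not code from the commit):

```python
from distutils.util import strtobool


def parse_bool(x):
    # type=lambda x: bool(strtobool(x)) in the script accepts y/yes/t/true/on/1
    # as True and n/no/f/false/off/0 as False; anything else raises ValueError.
    return bool(strtobool(x))


assert parse_bool("True") and parse_bool("on") and parse_bool("1")
assert not parse_bool("False") and not parse_bool("no")
```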
@@ -135,6 +141,7 @@ def main():
         repeat=ARGS.repeat,
         min_repeat_ms=ARGS.min_repeat_ms,
         enable_cpu_cache_flush=ARGS.cpu_flush,
+        timeout=ARGS.rpc_config.session_timeout_sec,
     )
 
     if ARGS.target.kind.name == "llvm":
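These keyword arguments configure the `auto_scheduler.RPCRunner` used for remote measurement; the new `timeout=` line gives each measurement job the RPC session's timeout budget instead of the runner default. A hedged sketch of the construction, with the tracker key/host/port and numeric values as placeholders rather than values from the commit:

```python
from tvm import auto_scheduler

# Placeholder values; the script fills these from ARGS and ARGS.rpc_config.
runner = auto_scheduler.RPCRunner(
    key="local",                  # device key registered with the RPC tracker
    host="127.0.0.1",             # tracker host (placeholder)
    port=9190,                    # tracker port (placeholder)
    repeat=3,                     # ARGS.repeat
    min_repeat_ms=100,            # ARGS.min_repeat_ms
    enable_cpu_cache_flush=True,  # ARGS.cpu_flush
    timeout=180,                  # ARGS.rpc_config.session_timeout_sec
)
```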
@@ -163,102 +170,63 @@ def main():
     onnx_model = onnx.load(ARGS.onnx_path)
     shape_dict = {}
     for item in ARGS.input_shape:
-        print(f"  input_name: {item['name']}")
+        print(f"  input_name : {item['name']}")
         print(f"  input_shape: {item['shape']}")
         print(f"  input_dtype: {item['dtype']}")
         shape_dict[item["name"]] = item["shape"]
     mod, params = from_onnx(onnx_model, shape_dict, freeze_params=True)
-    tasks, task_weights = auto_scheduler.extract_tasks(
-        mod["main"],
-        params,
-        target=ARGS.target,
-        hardware_params=hardware_params,
-    )
-    for idx, (task, task_weight) in enumerate(zip(tasks, task_weights)):
-        print(f"==== Task {idx}: {task.desc} (weight {task_weight} key: {task.workload_key}) =====")
-        print(task.compute_dag)
-
-    tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
-    tuner.tune(
-        auto_scheduler.TuningOptions(
-            num_measure_trials=ARGS.num_trials,
-            runner=runner,
-            measure_callbacks=[
-                auto_scheduler.RecordToFile(log_file),
-            ],
-        ),
-        adaptive_training=ARGS.adaptive_training,
-    )
-
-    with auto_scheduler.ApplyHistoryBest(log_file):
-        with tvm.transform.PassContext(
-            opt_level=3,
-            config={"relay.backend.use_auto_scheduler": True},
-        ):
-            lib = relay.build(
-                mod,
-                target=ARGS.target,
-                params=params,
+    input_data = {
+        item["name"]: generate_input_data(item["shape"], item["dtype"]) for item in ARGS.input_shape
+    }
+
+    with ms.Profiler() as profiler:
+        tasks, task_weights = auto_scheduler.extract_tasks(
+            mod["main"],
+            params,
+            target=ARGS.target,
+            hardware_params=hardware_params,
+        )
+        for idx, (task, task_weight) in enumerate(zip(tasks, task_weights)):
+            print(
+                f"==== Task {idx}: {task.desc} "
+                f"(weight {task_weight} key: {task.workload_key}) ====="
             )
-            graph, rt_mod, params = lib.graph_json, lib.lib, lib.params
-    input_data = {}
-    for item in ARGS.input_shape:
-        input_name, input_shape, input_dtype = item["name"], item["shape"], item["dtype"]
-        if input_dtype.startswith("float"):
-            input_data[input_name] = np.random.uniform(size=input_shape).astype(input_dtype)
-        else:
-            input_data[input_name] = np.random.randint(
-                low=0, high=10000, size=input_shape, dtype=input_dtype
+            print(task.compute_dag)
+
+        if ARGS.num_trials > 0:
+            tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
+            tuner.tune(
+                auto_scheduler.TuningOptions(
+                    num_measure_trials=ARGS.num_trials,
+                    runner=runner,
+                    measure_callbacks=[
+                        auto_scheduler.RecordToFile(log_file),
+                    ],
+                ),
+                adaptive_training=ARGS.adaptive_training,
            )
 
-    def f_timer(rt_mod, dev, input_data):
-        # pylint: disable=import-outside-toplevel
-        from tvm.contrib.graph_executor import GraphModule
-
-        # pylint: enable=import-outside-toplevel
-
-        mod = GraphModule(rt_mod["default"](dev))
-        for input_name, input_value in input_data.items():
-            mod.set_input(input_name, input_value)
-        ftimer = mod.module.time_evaluator(
-            "run",
-            dev,
-            min_repeat_ms=500,
-            repeat=3,
-        )
-        results = list(np.array(ftimer().results) * 1000.0)  # type: ignore
-        print("Running time in time_evaluator: ", results)
+        relay_build = {"graph": relay.build, "vm": relay.vm.compile}[ARGS.backend]
+        with auto_scheduler.ApplyHistoryBest(log_file):
+            with tvm.transform.PassContext(
+                opt_level=3,
+                config={"relay.backend.use_auto_scheduler": True},
+            ):
+                lib = relay_build(
+                    mod,
+                    target=ARGS.target,
+                    params=params,
+                )
+    print("Tuning Time:")
+    print(profiler.table())
 
     run_module_via_rpc(
         rpc_config=ARGS.rpc_config,
         lib=lib,
         dev_type=ARGS.target.kind.name,
         args=input_data,
-        continuation=f_timer,
-    )
-
-    def f_per_layer(rt_mod, dev, input_data):
-        # pylint: disable=import-outside-toplevel
-        from tvm.contrib.debugger.debug_executor import create
-
-        # pylint: enable=import-outside-toplevel
-        mod = create(graph, rt_mod, dev)
-        for input_name, input_value in input_data.items():
-            mod.set_input(input_name, input_value)
-        graph_nodes = [n["name"] for n in json.loads(graph)["nodes"]]
-        graph_time = mod.run_individual(number=10, repeat=1, min_repeat_ms=5000)
-        print("|graph_nodes| = ", len(graph_nodes))
-        print("|graph_time| = ", len(graph_time))
-        graph_nodes_time = {k: float(v) for k, v in zip(graph_nodes, graph_time)}
-        for k, v in graph_nodes_time.items():
-            print(f"{k} : {v:.3f}")
-
-    run_module_via_rpc(
-        rpc_config=ARGS.rpc_config,
-        lib=rt_mod,
-        dev_type=ARGS.target.kind.name,
-        args=input_data,
-        continuation=f_per_layer,
+        continuation=create_timer(ARGS.backend),
+        backend=ARGS.backend,
     )
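The deleted `f_timer` logic did not disappear; per the commit message ("Move utils to tune_utils") it moved into `tvm.meta_schedule.testing.tune_utils`. A rough sketch of what `generate_input_data` and the graph branch of `create_timer` plausibly look like, inferred from the inline code they replace; the real helpers also implement the `vm` branch and broader dtype handling:

```python
import numpy as np  # type: ignore


def generate_input_data(input_shape, input_dtype):
    # Mirrors the removed inline code: uniform floats for float dtypes,
    # small random integers for everything else.
    if input_dtype.startswith("float"):
        return np.random.uniform(size=input_shape).astype(input_dtype)
    return np.random.randint(low=0, high=10000, size=input_shape, dtype=input_dtype)


def create_timer(backend):
    # Returns a continuation matching run_module_via_rpc's expected
    # signature: f(rt_mod, dev, input_data).
    def f_timer(rt_mod, dev, input_data):
        if backend == "graph":
            from tvm.contrib.graph_executor import GraphModule  # pylint: disable=import-outside-toplevel

            mod = GraphModule(rt_mod["default"](dev))
            for name, value in input_data.items():
                mod.set_input(name, value)
            ftimer = mod.module.time_evaluator("run", dev, min_repeat_ms=500, repeat=3)
            print("Running time in time_evaluator: ", list(np.array(ftimer().results) * 1000.0))
        else:
            # The "vm" branch presumably wraps the executable in a
            # tvm.runtime.vm.VirtualMachine and times it analogously;
            # elided here because the commit page does not show it.
            raise NotImplementedError("vm timing elided in this sketch")

    return f_timer
```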