
Commit 8ebdf6e

[MetaSchedule] Misc update for e2e workloads (#10776)
1 parent 3918717 commit 8ebdf6e

File tree: 9 files changed (+431 / -16 lines)


python/tvm/meta_schedule/integration.py

Lines changed: 3 additions & 3 deletions

```diff
@@ -19,8 +19,7 @@

 import numpy as np  # type: ignore
 import tvm.runtime.ndarray as nd
-
-from tvm._ffi import register_object, get_global_func
+from tvm._ffi import get_global_func, register_object
 from tvm.ir import IRModule, transform
 from tvm.relay import Any
 from tvm.relay import Function as RelayFunc
@@ -29,6 +28,7 @@

 from . import _ffi_api
 from .database import Database
+from .utils import autotvm_silencer


 @register_object("meta_schedule.ExtractedTask")
@@ -234,7 +234,7 @@ def extract_task_from_relay(
     if not isinstance(target, Target):
         target = Target(target)

-    with target, transform.PassContext(
+    with autotvm_silencer(), target, transform.PassContext(
         opt_level=opt_level,
         config=pass_config,
         disabled_pass=disabled_pass,
```
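
The net effect of the `autotvm_silencer()` wrapper is that AutoTVM's "Cannot find config" fallback warnings no longer flood the console while MetaSchedule extracts tuning tasks. A minimal sketch of how extraction is typically invoked, assuming the `meta_schedule` integration API of this era (the workload name and target string are illustrative placeholders, not values from this commit):

```python
# Minimal sketch; "resnet_18" and the target string are placeholders,
# not values taken from this commit.
import tvm
from tvm.meta_schedule.integration import extract_task_from_relay
from tvm.meta_schedule.testing.relay_workload import get_network

mod, params, _ = get_network("resnet_18", [1, 3, 224, 224])
# autotvm_silencer() is applied inside extract_task_from_relay after this
# change, so no caller-side code is needed to mute AutoTVM warnings.
extracted_tasks = extract_task_from_relay(
    mod,
    target=tvm.target.Target("llvm -num-cores 4"),
    params=params,
)
for task in extracted_tasks:
    print(task.task_name)
```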

python/tvm/meta_schedule/testing/custom_builder_runner.py

Lines changed: 32 additions & 2 deletions

```diff
@@ -17,11 +17,12 @@
 """Customized builder and runner methods"""
 # pylint: disable=import-outside-toplevel

-from typing import TYPE_CHECKING, Dict, List
+from typing import TYPE_CHECKING, Callable, Dict, List

 if TYPE_CHECKING:
+    import numpy as np  # type: ignore
     from tvm.ir import IRModule
-    from tvm.meta_schedule.runner import EvaluatorConfig
+    from tvm.meta_schedule.runner import EvaluatorConfig, RPCConfig
     from tvm.runtime import Device, Module, NDArray
     from tvm.target import Target

@@ -138,3 +139,32 @@ def run_with_graph_executor(
         repeated_costs.append(profile_result.results)
     costs = [float(cost) for cost in itertools.chain.from_iterable(repeated_costs)]
     return costs
+
+
+def run_module_via_rpc(
+    rpc_config: "RPCConfig",
+    lib: "Module",
+    dev_type: str,
+    args: List["np.ndarray"],
+    continuation: Callable,
+):
+    """Execute a tvm.runtime.Module on RPC remote"""
+    # pylint: disable=import-outside-toplevel
+    import os
+    import tempfile
+
+    from tvm.contrib.tar import tar
+    from tvm.runtime import ndarray
+
+    # pylint: enable=import-outside-toplevel
+
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        filename = os.path.join(tmp_dir, "tvm_tmp_mod." + tar.output_format)
+        lib.export_library(filename, tar)
+        session = rpc_config.connect_server()
+        session.upload(filename)
+        _, filename = os.path.split(filename)
+        rt_mod = session.load_module(filename)
+        dev = session.device(dev_type=dev_type, dev_id=0)
+        args = [ndarray.array(arg, dev) for arg in args]
+        return continuation(rt_mod, dev, *args)
```
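
`run_module_via_rpc` exports the built module to a tarball, uploads it over an RPC session, stages the input arrays on the remote device, and then hands the loaded module, device, and staged arguments to a caller-supplied `continuation`. A minimal, self-contained sketch of the calling pattern, assuming a reachable RPC tracker (the tracker host, port, and key below are placeholders):

```python
# Minimal sketch; tracker_host/tracker_port/tracker_key are placeholders,
# not values from this commit.
import numpy as np
import tvm
from tvm import meta_schedule as ms
from tvm import relay
from tvm.meta_schedule.testing.custom_builder_runner import run_module_via_rpc

# Build a tiny Relay module so the example is self-contained.
x = relay.var("x", shape=(1, 16), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([x], relay.nn.relu(x)))
lib = relay.build(mod, target="llvm")


def f_output(rt_mod, dev, arg):
    # Called with the remote-loaded module, the remote device, and the
    # inputs already staged on that device as NDArrays.
    from tvm.contrib.graph_executor import GraphModule

    graph_mod = GraphModule(rt_mod["default"](dev))
    graph_mod.set_input("x", arg)
    graph_mod.run()
    return graph_mod.get_output(0).numpy()


result = run_module_via_rpc(
    rpc_config=ms.runner.RPCConfig(
        tracker_host="127.0.0.1",  # placeholder
        tracker_port=9190,  # placeholder
        tracker_key="local",  # placeholder
        session_timeout_sec=600,
    ),
    lib=lib,
    dev_type="cpu",
    args=[np.random.uniform(size=(1, 16)).astype("float32")],
    continuation=f_output,
)
print(result)
```

The continuation style keeps the RPC session open for exactly as long as the caller needs the remote module; the `f_timer` benchmark in the new script below uses the same pattern.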

python/tvm/meta_schedule/testing/relay_workload.py

Lines changed: 4 additions & 1 deletion

```diff
@@ -16,6 +16,7 @@
 # under the License.
 """Workloads in Relay IR"""
 # pylint: disable=import-outside-toplevel
+import logging
 import multiprocessing
 import os
 import pickle
@@ -29,6 +30,8 @@
 from tvm.runtime import NDArray, load_param_dict, save_param_dict
 from tvm.target import Target

+logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
+

 def _get_network(
     args: Tuple[str, List[int]]
@@ -170,7 +173,7 @@ def _load_cache(cache_dir: Optional[str], filename: str) -> Optional[List[Any]]:
     path = os.path.join(os.path.expanduser(cache_dir), filename)
     if not os.path.exists(path):
         return None
-    print(f"Load from cache: {path}")
+    logger.info("Loaded from cached: %s", path)
     with open(path, "rb") as i_f:
         return pickle.load(i_f)

```
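
One behavioral consequence of switching from `print` to `logging`: the cache-hit message is now invisible unless the host process configures logging. A one-line sketch using only the standard library (not part of this commit):

```python
import logging

# Surface INFO-level messages, including the cache-hit message emitted by
# tvm.meta_schedule.testing.relay_workload._load_cache.
logging.basicConfig(level=logging.INFO)
```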

(New file; the file path was not captured in this view)

Lines changed: 206 additions & 0 deletions

```python
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-docstring
import argparse
import json
import os

import numpy as np  # type: ignore
import tvm
from tvm import auto_scheduler
from tvm import meta_schedule as ms
from tvm import relay
from tvm.meta_schedule.testing.custom_builder_runner import run_module_via_rpc
from tvm.meta_schedule.testing.relay_workload import get_network


def _parse_args():
    args = argparse.ArgumentParser()
    args.add_argument(
        "--workload",
        type=str,
        required=True,
    )
    args.add_argument(
        "--input-shape",
        type=str,
        required=True,
    )
    args.add_argument(
        "--target",
        type=str,
        required=True,
    )
    args.add_argument(
        "--num-trials",
        type=int,
        required=True,
    )
    args.add_argument(
        "--rpc-host",
        type=str,
        required=True,
    )
    args.add_argument(
        "--rpc-port",
        type=int,
        required=True,
    )
    args.add_argument(
        "--rpc-key",
        type=str,
        required=True,
    )
    args.add_argument(
        "--rpc-workers",
        type=int,
        required=True,
    )
    args.add_argument(
        "--log-dir",
        type=str,
        required=True,
    )
    args.add_argument(
        "--cache-dir",
        type=str,
        default=None,
    )
    parsed = args.parse_args()
    parsed.target = tvm.target.Target(parsed.target)
    parsed.input_shape = json.loads(parsed.input_shape)
    parsed.rpc_config = ms.runner.RPCConfig(
        tracker_host=parsed.rpc_host,
        tracker_port=parsed.rpc_port,
        tracker_key=parsed.rpc_key,
        session_timeout_sec=3600,
    )
    return parsed


ARGS = _parse_args()


def main():
    log_file = os.path.join(ARGS.log_dir, f"{ARGS.workload}.json")

    runner = auto_scheduler.RPCRunner(
        key=ARGS.rpc_key,
        host=ARGS.rpc_host,
        port=ARGS.rpc_port,
        n_parallel=ARGS.rpc_workers,
        number=3,
        repeat=1,
        min_repeat_ms=100,  # TODO
        enable_cpu_cache_flush=False,  # TODO
    )

    if ARGS.target.kind.name == "llvm":
        hardware_params = auto_scheduler.HardwareParams(
            num_cores=int(ARGS.target.attrs["num-cores"]),
            target=ARGS.target,
        )
    elif ARGS.target.kind.name == "cuda":
        hardware_params = auto_scheduler.HardwareParams(
            num_cores=-1,
            vector_unit_bytes=16,
            cache_line_bytes=64,
            max_shared_memory_per_block=int(ARGS.target.attrs["max_shared_memory_per_block"]),
            max_threads_per_block=int(ARGS.target.attrs["max_threads_per_block"]),
            # The value `max_local_memory_per_block` is not used in AutoScheduler,
            # but is required by the API.
            max_local_memory_per_block=12345678,
            max_vthread_extent=8,
            warp_size=32,
        )
    else:
        raise NotImplementedError(f"Unsupported target {ARGS.target}")
    mod, params, (input_name, input_shape, input_dtype) = get_network(
        ARGS.workload,
        ARGS.input_shape,
        cache_dir=ARGS.cache_dir,
    )
    print(f"Workload: {ARGS.workload}")
    print(f"  input_name: {input_name}")
    print(f"  input_shape: {input_shape}")
    print(f"  input_dtype: {input_dtype}")
    tasks, task_weights = auto_scheduler.extract_tasks(
        mod["main"],
        params,
        target=ARGS.target,
        hardware_params=hardware_params,
    )
    for idx, (task, task_weight) in enumerate(zip(tasks, task_weights)):
        print(f"==== Task {idx}: {task.desc} (weight {task_weight} key: {task.workload_key}) =====")
        print(task.compute_dag)

    tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
    tuner.tune(
        auto_scheduler.TuningOptions(
            num_measure_trials=ARGS.num_trials,
            runner=runner,
            measure_callbacks=[
                auto_scheduler.RecordToFile(log_file),
            ],
        )
    )

    with auto_scheduler.ApplyHistoryBest(log_file):
        with tvm.transform.PassContext(
            opt_level=3,
            config={"relay.backend.use_auto_scheduler": True},
        ):
            lib = relay.build(
                mod,
                target=ARGS.target,
                params=params,
            )

    if input_dtype.startswith("float"):
        input_data = np.random.uniform(size=input_shape).astype(input_dtype)
    else:
        input_data = np.random.randint(low=0, high=10000, size=input_shape, dtype=input_dtype)

    def f_timer(rt_mod, dev, input_data):
        # pylint: disable=import-outside-toplevel
        from tvm.contrib.graph_executor import GraphModule

        # pylint: enable=import-outside-toplevel

        mod = GraphModule(rt_mod["default"](dev))
        mod.set_input(input_name, input_data)
        ftimer = mod.module.time_evaluator(
            "run",
            dev,
            min_repeat_ms=500,
            repeat=3,
        )
        return list(np.array(ftimer().results))

    results = run_module_via_rpc(
        rpc_config=ARGS.rpc_config,
        lib=lib,
        dev_type=ARGS.target.kind.name,
        args=[input_data],
        continuation=f_timer,
    )

    print(results)


if __name__ == "__main__":
    main()
```
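
The new script reads hardware attributes directly off the parsed `--target`: `num-cores` for LLVM targets, and `max_threads_per_block` plus `max_shared_memory_per_block` for CUDA targets. A hedged sketch of target strings that satisfy those lookups (the core count and GPU tag are example values, not prescribed by the commit):

```python
# Sketch only: "-num-cores 16" and the RTX 3080 tag are example values.
import tvm

# LLVM targets must spell out the core count explicitly.
cpu = tvm.target.Target("llvm -num-cores 16")
print(int(cpu.attrs["num-cores"]))  # feeds HardwareParams(num_cores=...)

# Registered CUDA target tags bundle the per-block limits the script reads.
gpu = tvm.target.Target("nvidia/geforce-rtx-3080")
print(int(gpu.attrs["max_threads_per_block"]))
print(int(gpu.attrs["max_shared_memory_per_block"]))
```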
