15 | 15 | # specific language governing permissions and limitations |
16 | 16 | # under the License. |
17 | 17 | # pylint: disable=missing-docstring |
| 18 | +from distutils.util import strtobool |
18 | 19 | import argparse |
19 | 20 | import json |
20 | 21 | import os |
21 | | - |
22 | | -from distutils.util import strtobool |
23 | | -import numpy as np # type: ignore |
24 | 22 | import onnx # type: ignore |
| 23 | + |
25 | 24 | import tvm |
26 | 25 | from tvm import auto_scheduler |
27 | 26 | from tvm import meta_schedule as ms |
28 | 27 | from tvm import relay |
29 | 28 | from tvm.meta_schedule.testing.custom_builder_runner import run_module_via_rpc |
| 29 | +from tvm.meta_schedule.testing.tune_utils import generate_input_data, create_timer |
30 | 30 | from tvm.meta_schedule.utils import cpu_count |
31 | 31 | from tvm.relay.frontend import from_onnx |
32 | 32 | from tvm.support import describe |
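The NumPy import goes away because the inline input generation deleted further down is now handled by `generate_input_data` from `tvm.meta_schedule.testing.tune_utils`. A minimal sketch of what such a helper does, assuming it keeps the float/int split of the deleted code (the body below is reconstructed from the removed lines, not copied from `tune_utils`):

```python
import numpy as np

def generate_input_data(input_shape, input_dtype):
    """Random inputs in the style of the deleted inline code:
    uniform floats for float dtypes, bounded random integers otherwise."""
    if input_dtype.startswith("float"):
        return np.random.uniform(size=input_shape).astype(input_dtype)
    return np.random.randint(low=0, high=10000, size=input_shape, dtype=input_dtype)
```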
@@ -96,17 +96,23 @@ def _parse_args(): |
96 | 96 | default=100, |
97 | 97 | ) |
98 | 98 | args.add_argument( |
99 | | - "--cpu-flush", |
| 99 | + "--adaptive-training", |
100 | 100 | type=lambda x: bool(strtobool(x)), |
101 | | - required=True, |
102 | 101 | help="example: True / False", |
| 102 | + default=True, |
103 | 103 | ) |
104 | 104 | args.add_argument( |
105 | | - "--adaptive-training", |
| 105 | + "--cpu-flush", |
106 | 106 | type=lambda x: bool(strtobool(x)), |
107 | | - required=False, |
108 | 107 | help="example: True / False", |
109 | | - default=True, |
| 108 | + required=True, |
| 109 | + ) |
| 110 | + args.add_argument( |
| 111 | + "--backend", |
| 112 | + type=str, |
| 113 | + choices=["graph", "vm"], |
| 114 | + help="example: graph / vm", |
| 115 | + required=True, |
110 | 116 | ) |
111 | 117 | parsed = args.parse_args() |
112 | 118 | parsed.target = tvm.target.Target(parsed.target) |
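Both boolean flags go through `bool(strtobool(x))` because plain `type=bool` in argparse treats any non-empty string, including `"False"`, as truthy. A standalone sketch of the pattern together with the new required `--backend` choice (flag names taken from this diff; the parse arguments are illustrative):

```python
import argparse
from distutils.util import strtobool  # deprecated since Python 3.10, still used here

parser = argparse.ArgumentParser()
parser.add_argument("--cpu-flush", type=lambda x: bool(strtobool(x)), required=True)
parser.add_argument("--backend", type=str, choices=["graph", "vm"], required=True)

args = parser.parse_args(["--cpu-flush", "False", "--backend", "graph"])
assert args.cpu_flush is False  # strtobool maps "false"/"0"/"no" to 0
assert args.backend == "graph"
```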
@@ -135,6 +141,7 @@ def main(): |
135 | 141 | repeat=ARGS.repeat, |
136 | 142 | min_repeat_ms=ARGS.min_repeat_ms, |
137 | 143 | enable_cpu_cache_flush=ARGS.cpu_flush, |
| 144 | + timeout=ARGS.rpc_config.session_timeout_sec, |
138 | 145 | ) |
139 | 146 | |
140 | 147 | if ARGS.target.kind.name == "llvm": |
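Passing `timeout=ARGS.rpc_config.session_timeout_sec` makes a slow measurement candidate time out at the same bound as the RPC session instead of relying on the runner's small default. A sketch of the surrounding construction, assuming an `auto_scheduler.RPCRunner` fed from a meta_schedule `RPCConfig` (field names as in `ms.runner.RPCConfig`; the rest is reconstructed from the visible keyword arguments):

```python
runner = auto_scheduler.RPCRunner(
    key=ARGS.rpc_config.tracker_key,
    host=ARGS.rpc_config.tracker_host,
    port=ARGS.rpc_config.tracker_port,
    repeat=ARGS.repeat,
    min_repeat_ms=ARGS.min_repeat_ms,
    enable_cpu_cache_flush=ARGS.cpu_flush,
    timeout=ARGS.rpc_config.session_timeout_sec,  # the line added in this hunk
)
```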
@@ -163,102 +170,63 @@ def main(): |
163 | 170 | onnx_model = onnx.load(ARGS.onnx_path) |
164 | 171 | shape_dict = {} |
165 | 172 | for item in ARGS.input_shape: |
166 | | - print(f" input_name: {item['name']}") |
| 173 | + print(f" input_name : {item['name']}") |
167 | 174 | print(f" input_shape: {item['shape']}") |
168 | 175 | print(f" input_dtype: {item['dtype']}") |
169 | 176 | shape_dict[item["name"]] = item["shape"] |
170 | 177 | mod, params = from_onnx(onnx_model, shape_dict, freeze_params=True) |
171 | | - tasks, task_weights = auto_scheduler.extract_tasks( |
172 | | - mod["main"], |
173 | | - params, |
174 | | - target=ARGS.target, |
175 | | - hardware_params=hardware_params, |
176 | | - ) |
177 | | - for idx, (task, task_weight) in enumerate(zip(tasks, task_weights)): |
178 | | - print(f"==== Task {idx}: {task.desc} (weight {task_weight} key: {task.workload_key}) =====") |
179 | | - print(task.compute_dag) |
180 | | - |
181 | | - tuner = auto_scheduler.TaskScheduler(tasks, task_weights) |
182 | | - tuner.tune( |
183 | | - auto_scheduler.TuningOptions( |
184 | | - num_measure_trials=ARGS.num_trials, |
185 | | - runner=runner, |
186 | | - measure_callbacks=[ |
187 | | - auto_scheduler.RecordToFile(log_file), |
188 | | - ], |
189 | | - ), |
190 | | - adaptive_training=ARGS.adaptive_training, |
191 | | - ) |
192 | | - |
193 | | - with auto_scheduler.ApplyHistoryBest(log_file): |
194 | | - with tvm.transform.PassContext( |
195 | | - opt_level=3, |
196 | | - config={"relay.backend.use_auto_scheduler": True}, |
197 | | - ): |
198 | | - lib = relay.build( |
199 | | - mod, |
200 | | - target=ARGS.target, |
201 | | - params=params, |
| 178 | + input_data = { |
| 179 | + item["name"]: generate_input_data(item["shape"], item["dtype"]) for item in ARGS.input_shape |
| 180 | + } |
| 181 | + |
| 182 | + with ms.Profiler() as profiler: |
| 183 | + tasks, task_weights = auto_scheduler.extract_tasks( |
| 184 | + mod["main"], |
| 185 | + params, |
| 186 | + target=ARGS.target, |
| 187 | + hardware_params=hardware_params, |
| 188 | + ) |
| 189 | + for idx, (task, task_weight) in enumerate(zip(tasks, task_weights)): |
| 190 | + print( |
| 191 | + f"==== Task {idx}: {task.desc} " |
| 192 | + f"(weight {task_weight} key: {task.workload_key}) =====" |
202 | 193 | ) |
203 | | - graph, rt_mod, params = lib.graph_json, lib.lib, lib.params |
204 | | - input_data = {} |
205 | | - for item in ARGS.input_shape: |
206 | | - input_name, input_shape, input_dtype = item["name"], item["shape"], item["dtype"] |
207 | | - if input_dtype.startswith("float"): |
208 | | - input_data[input_name] = np.random.uniform(size=input_shape).astype(input_dtype) |
209 | | - else: |
210 | | - input_data[input_name] = np.random.randint( |
211 | | - low=0, high=10000, size=input_shape, dtype=input_dtype |
| 194 | + print(task.compute_dag) |
| 195 | + |
| 196 | + if ARGS.num_trials > 0: |
| 197 | + tuner = auto_scheduler.TaskScheduler(tasks, task_weights) |
| 198 | + tuner.tune( |
| 199 | + auto_scheduler.TuningOptions( |
| 200 | + num_measure_trials=ARGS.num_trials, |
| 201 | + runner=runner, |
| 202 | + measure_callbacks=[ |
| 203 | + auto_scheduler.RecordToFile(log_file), |
| 204 | + ], |
| 205 | + ), |
| 206 | + adaptive_training=ARGS.adaptive_training, |
212 | 207 | ) |
213 | 208 | |
214 | | - def f_timer(rt_mod, dev, input_data): |
215 | | - # pylint: disable=import-outside-toplevel |
216 | | - from tvm.contrib.graph_executor import GraphModule |
217 | | - |
218 | | - # pylint: enable=import-outside-toplevel |
219 | | - |
220 | | - mod = GraphModule(rt_mod["default"](dev)) |
221 | | - for input_name, input_value in input_data.items(): |
222 | | - mod.set_input(input_name, input_value) |
223 | | - ftimer = mod.module.time_evaluator( |
224 | | - "run", |
225 | | - dev, |
226 | | - min_repeat_ms=500, |
227 | | - repeat=3, |
228 | | - ) |
229 | | - results = list(np.array(ftimer().results) * 1000.0) # type: ignore |
230 | | - print("Running time in time_evaluator: ", results) |
| 209 | + relay_build = {"graph": relay.build, "vm": relay.vm.compile}[ARGS.backend] |
| 210 | + with auto_scheduler.ApplyHistoryBest(log_file): |
| 211 | + with tvm.transform.PassContext( |
| 212 | + opt_level=3, |
| 213 | + config={"relay.backend.use_auto_scheduler": True}, |
| 214 | + ): |
| 215 | + lib = relay_build( |
| 216 | + mod, |
| 217 | + target=ARGS.target, |
| 218 | + params=params, |
| 219 | + ) |
| 220 | + print("Tuning Time:") |
| 221 | + print(profiler.table()) |
231 | 222 | |
232 | 223 | run_module_via_rpc( |
233 | 224 | rpc_config=ARGS.rpc_config, |
234 | 225 | lib=lib, |
235 | 226 | dev_type=ARGS.target.kind.name, |
236 | 227 | args=input_data, |
237 | | - continuation=f_timer, |
238 | | - ) |
239 | | - |
240 | | - def f_per_layer(rt_mod, dev, input_data): |
241 | | - # pylint: disable=import-outside-toplevel |
242 | | - from tvm.contrib.debugger.debug_executor import create |
243 | | - |
244 | | - # pylint: enable=import-outside-toplevel |
245 | | - mod = create(graph, rt_mod, dev) |
246 | | - for input_name, input_value in input_data.items(): |
247 | | - mod.set_input(input_name, input_value) |
248 | | - graph_nodes = [n["name"] for n in json.loads(graph)["nodes"]] |
249 | | - graph_time = mod.run_individual(number=10, repeat=1, min_repeat_ms=5000) |
250 | | - print("|graph_nodes| = ", len(graph_nodes)) |
251 | | - print("|graph_time| = ", len(graph_time)) |
252 | | - graph_nodes_time = {k: float(v) for k, v in zip(graph_nodes, graph_time)} |
253 | | - for k, v in graph_nodes_time.items(): |
254 | | - print(f"{k} : {v:.3f}") |
255 | | - |
256 | | - run_module_via_rpc( |
257 | | - rpc_config=ARGS.rpc_config, |
258 | | - lib=rt_mod, |
259 | | - dev_type=ARGS.target.kind.name, |
260 | | - args=input_data, |
261 | | - continuation=f_per_layer, |
| 228 | + continuation=create_timer(ARGS.backend), |
| 229 | + backend=ARGS.backend, |
262 | 230 | ) |
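The two ad-hoc closures (`f_timer` for end-to-end timing, `f_per_layer` for the debug-executor breakdown) collapse into `create_timer(ARGS.backend)` from `tune_utils`, and `run_module_via_rpc` now receives the backend so it can build the right executor remotely. A plausible sketch of the graph branch of such a factory, reconstructed from the deleted `f_timer` (the `"vm"` branch would wrap `tvm.runtime.vm.VirtualMachine` analogously; those details are assumptions):

```python
import numpy as np

def create_timer(backend):
    """Return a continuation that times an uploaded module on a remote device."""

    def f_timer(rt_mod, dev, input_data):
        # pylint: disable=import-outside-toplevel
        from tvm.contrib.graph_executor import GraphModule

        # pylint: enable=import-outside-toplevel
        assert backend == "graph", "only the graph branch is sketched here"
        mod = GraphModule(rt_mod["default"](dev))
        for input_name, input_value in input_data.items():
            mod.set_input(input_name, input_value)
        ftimer = mod.module.time_evaluator("run", dev, min_repeat_ms=500, repeat=3)
        results = list(np.array(ftimer().results) * 1000.0)  # seconds -> milliseconds
        print("Running time in time_evaluator: ", results)

    return f_timer
```

The `ms.Profiler` context added around task extraction, tuning, and `relay_build` is what feeds the new `print(profiler.table())`, so tuning time shows up in the log alongside the timing results.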
263 | 231 | |
264 | 232 | |