|
27 | 27 | from tvm.rpc.server import Server |
28 | 28 | from tvm.rpc.tracker import Tracker |
29 | 29 | from tvm.meta_schedule.logging import get_logger |
30 | | -from tvm.meta_schedule.utils import ( |
31 | | - cpu_count, |
32 | | - derived_object, |
33 | | -) |
| 30 | +from tvm.meta_schedule.utils import cpu_count, derived_object |
34 | 31 | from tvm.meta_schedule.runner.config import EvaluatorConfig, RPCConfig |
35 | 32 | from tvm.meta_schedule.runner import PyRunner, RunnerFuture, RunnerInput |
36 | 33 | from tvm.meta_schedule.runner.rpc_runner import RPCRunnerFuture |
@@ -159,51 +156,44 @@ def _worker_func( |
159 | 156 | def get_rpc_runner_micro( |
160 | 157 | platform, |
161 | 158 | options, |
162 | | - session_timeout_sec: int = 300, |
163 | | - number: int = 3, |
164 | | - repeat: int = 1, |
165 | | - min_repeat_ms: int = 100, |
| 159 | + rpc_config: RPCConfig = None, |
| 160 | + evaluator_config: EvaluatorConfig = None, |
| 161 | + session_timeout_sec=300, |
166 | 162 | ): |
167 | 163 | """Parameters |
168 | 164 | ---------- |
169 | 165 | platform: str |
170 | 166 | The platform used for project generation. |
171 | 167 | project_options: dict |
172 | 168 | The options for the generated micro project. |
| 169 | + rpc_config: RPCConfig |
| 170 | + The rpc configuration. |
| 171 | + evaluator_config: EvaluatorConfig |
| 172 | + The evaluator configuration. |
173 | 173 | session_timeout_sec: int |
174 | 174 | The session timeout. if the number of candidates sent to runner is larger |
175 | 175 | than the runner workers, increase the timeout. |
176 | | - number: int |
177 | | - The number of times to run the evaluator function for taking average. |
178 | | - We call these runs as one `repeat` of measurement. |
179 | | - repeat: int |
180 | | - The number of times to repeat the measurement. |
181 | | - In total, the function will be invoked (1 + number x repeat) times, |
182 | | - where the first one is warm up and will be discarded. |
183 | | - The returned result contains `repeat` costs, |
184 | | - each of which is an average of `number` costs. |
185 | | - min_repeat_ms: int |
186 | | - Minimum repeat time in ms. if the execution latency is too short, |
187 | | - increase the number of runs to the given time (in ms) to reduce the measurement error. |
188 | 176 | """ |
189 | | - tracker_host = "127.0.0.1" |
190 | | - tracker_port = 9000 |
191 | | - tracker_key = "$local$device$%d" % tracker_port |
192 | | - rpc_config = RPCConfig( |
193 | | - tracker_host=tracker_host, |
194 | | - tracker_port=tracker_port, |
195 | | - tracker_key=tracker_key, |
196 | | - session_priority=0, |
197 | | - session_timeout_sec=session_timeout_sec, |
198 | | - ) |
199 | | - rpc_config = RPCConfig._normalized(rpc_config) |
200 | | - tracker_port_end = 10000 |
201 | | - evaluator_config = EvaluatorConfig( |
202 | | - number=number, |
203 | | - repeat=repeat, |
204 | | - min_repeat_ms=min_repeat_ms, |
205 | | - enable_cpu_cache_flush=False, |
206 | | - ) |
| 177 | + if rpc_config is None: |
| 178 | + tracker_host = "127.0.0.1" |
| 179 | + tracker_port = 9000 |
| 180 | + tracker_key = "$local$device$%d" % tracker_port |
| 181 | + rpc_config = RPCConfig( |
| 182 | + tracker_host=tracker_host, |
| 183 | + tracker_port=tracker_port, |
| 184 | + tracker_key=tracker_key, |
| 185 | + session_priority=0, |
| 186 | + session_timeout_sec=session_timeout_sec, |
| 187 | + ) |
| 188 | + tracker_port_end = rpc_config.tracker_port + 1000 |
| 189 | + |
| 190 | + if evaluator_config is None: |
| 191 | + evaluator_config = EvaluatorConfig( |
| 192 | + number=3, |
| 193 | + repeat=1, |
| 194 | + min_repeat_ms=100, |
| 195 | + enable_cpu_cache_flush=False, |
| 196 | + ) |
207 | 197 |
|
208 | 198 | tracker = Tracker( |
209 | 199 | port=rpc_config.tracker_port, |
|
0 commit comments