Skip to content

Commit c0396fa

Browse files
committed
addressing comments
1 parent ba09232 commit c0396fa

File tree

4 files changed

+42
-50
lines changed

4 files changed

+42
-50
lines changed

python/tvm/contrib/micro/meta_schedule/local_builder_micro.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# KIND, either express or implied. See the License for the
1515
# specific language governing permissions and limitations
1616
# under the License.
17-
"""Local builder for microTVM projects.that compile on the local host"""
17+
"""Local builder for microTVM projects that compile on the local host"""
1818

1919
import os
2020
import tempfile
@@ -34,10 +34,10 @@
3434
def get_micro_local_builder():
3535
"""Return micro-compatible Builder for meta schedule."""
3636

37-
def micro_build(
37+
def _micro_build(
3838
mod: IRModule, target: Target, _params: Optional[Dict[str, NDArray]]
3939
) -> OperatorModule:
40-
"""build function for micro targets.
40+
"""Build function for micro targets.
4141
4242
Parameters
4343
----------
@@ -54,7 +54,8 @@ def micro_build(
5454
The built Module.
5555
"""
5656

57-
# Note: changing the global symbol is necessary for micro targets,
57+
# Note: tvm_build assigns "global_symbol" to the name of generated C function
58+
# changing it is necessary for micro targets,
5859
# since the generated projects already include a main function.
5960
prim_func = mod["main"].with_attr("global_symbol", "default_function")
6061
mod = IRModule({"main": prim_func})
@@ -63,8 +64,8 @@ def micro_build(
6364
rt_mod = tvm_build(mod, target=target, runtime=runtime)
6465
return rt_mod
6566

66-
def micro_export(mod: OperatorModule) -> str:
67-
"""export function for micro targets.
67+
def _micro_export(mod: OperatorModule) -> str:
68+
"""Export function for micro targets.
6869
6970
Parameters
7071
----------
@@ -80,4 +81,4 @@ def micro_export(mod: OperatorModule) -> str:
8081
micro.export_model_library_format(mod, artifact_path)
8182
return artifact_path
8283

83-
return LocalBuilder(f_build=micro_build, f_export=micro_export)
84+
return LocalBuilder(f_build=_micro_build, f_export=_micro_export)

python/tvm/contrib/micro/meta_schedule/rpc_runner_micro.py

Lines changed: 28 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,7 @@
2727
from tvm.rpc.server import Server
2828
from tvm.rpc.tracker import Tracker
2929
from tvm.meta_schedule.logging import get_logger
30-
from tvm.meta_schedule.utils import (
31-
cpu_count,
32-
derived_object,
33-
)
30+
from tvm.meta_schedule.utils import cpu_count, derived_object
3431
from tvm.meta_schedule.runner.config import EvaluatorConfig, RPCConfig
3532
from tvm.meta_schedule.runner import PyRunner, RunnerFuture, RunnerInput
3633
from tvm.meta_schedule.runner.rpc_runner import RPCRunnerFuture
@@ -159,51 +156,44 @@ def _worker_func(
159156
def get_rpc_runner_micro(
160157
platform,
161158
options,
162-
session_timeout_sec: int = 300,
163-
number: int = 3,
164-
repeat: int = 1,
165-
min_repeat_ms: int = 100,
159+
rpc_config: RPCConfig = None,
160+
evaluator_config: EvaluatorConfig = None,
161+
session_timeout_sec=300,
166162
):
167163
"""Parameters
168164
----------
169165
platform: str
170166
The platform used for project generation.
171167
project_options: dict
172168
The options for the generated micro project.
169+
rpc_config: RPCConfig
170+
The rpc configuration.
171+
evaluator_config: EvaluatorConfig
172+
The evaluator configuration.
173173
session_timeout_sec: int
174174
The session timeout. if the number of candidates sent to runner is larger
175175
than the runner workers, increase the timeout.
176-
number: int
177-
The number of times to run the evaluator function for taking average.
178-
We call these runs as one `repeat` of measurement.
179-
repeat: int
180-
The number of times to repeat the measurement.
181-
In total, the function will be invoked (1 + number x repeat) times,
182-
where the first one is warm up and will be discarded.
183-
The returned result contains `repeat` costs,
184-
each of which is an average of `number` costs.
185-
min_repeat_ms: int
186-
Minimum repeat time in ms. if the execution latency is too short,
187-
increase the number of runs to the given time (in ms) to reduce the measurement error.
188176
"""
189-
tracker_host = "127.0.0.1"
190-
tracker_port = 9000
191-
tracker_key = "$local$device$%d" % tracker_port
192-
rpc_config = RPCConfig(
193-
tracker_host=tracker_host,
194-
tracker_port=tracker_port,
195-
tracker_key=tracker_key,
196-
session_priority=0,
197-
session_timeout_sec=session_timeout_sec,
198-
)
199-
rpc_config = RPCConfig._normalized(rpc_config)
200-
tracker_port_end = 10000
201-
evaluator_config = EvaluatorConfig(
202-
number=number,
203-
repeat=repeat,
204-
min_repeat_ms=min_repeat_ms,
205-
enable_cpu_cache_flush=False,
206-
)
177+
if rpc_config is None:
178+
tracker_host = "127.0.0.1"
179+
tracker_port = 9000
180+
tracker_key = "$local$device$%d" % tracker_port
181+
rpc_config = RPCConfig(
182+
tracker_host=tracker_host,
183+
tracker_port=tracker_port,
184+
tracker_key=tracker_key,
185+
session_priority=0,
186+
session_timeout_sec=session_timeout_sec,
187+
)
188+
tracker_port_end = rpc_config.tracker_port + 1000
189+
190+
if evaluator_config is None:
191+
evaluator_config = EvaluatorConfig(
192+
number=3,
193+
repeat=1,
194+
min_repeat_ms=100,
195+
enable_cpu_cache_flush=False,
196+
)
207197

208198
tracker = Tracker(
209199
port=rpc_config.tracker_port,

python/tvm/contrib/micro/meta_schedule/test_autotune_ms.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def test_micro_tuning_with_meta_schedule(platform, options):
105105
platform=platform, options=options, session_timeout_sec=120
106106
) as runner:
107107
with ms.Profiler() as profiler:
108-
db: tvm.runtime.Module = ms.relay_integration.tune_relay(
108+
db: ms.Database = ms.relay_integration.tune_relay(
109109
mod=mod,
110110
params=params,
111111
target=target,

src/meta_schedule/schedule_rule/schedule_rule.cc

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -261,10 +261,11 @@ Array<ScheduleRule> ScheduleRule::DefaultMicro() {
261261
/*require_injective=*/true,
262262
/*require_ordered=*/true,
263263
/*disallow_op=*/Array<String>{"tir.exp"}),
264-
ScheduleRule::MultiLevelTilingWideVector(
265-
/*structure=*/"SRSRS",
266-
/*vector_length_in_bits=*/1024,
267-
/*max_innermost_factor=*/Integer(128),
264+
ScheduleRule::MultiLevelTiling(
265+
/*structure=*/"SSRSRS",
266+
/*tile_binds=*/NullOpt,
267+
/*max_innermost_factor=*/Integer(64),
268+
/*vector_load_lens=*/NullOpt,
268269
/*reuse_read=*/NullOpt,
269270
/*reuse_write=*/
270271
Map<String, ObjectRef>{{"req", String("may")},

0 commit comments

Comments
 (0)