Skip to content

Commit 56a186d

Browse files
[MetaSchedule] Developer Ergonomics Enhancement
Per discussion with @Kathryn-cat - [x] Move `initialize_with_tune_context` as private API `_initialize_with_tune_context`, and encourage using `TuneContext.initialize` - [x] Instead of using bunch of import statements, encourage using `ms.xxx` as the prefix (e.g. `ms.database.MemoryDatabase`) to organize things better - [x] Move `DefaultLLVM`, `DefaultCUDA` to a separate file and make them more discoverable - [x] Move `DummyDatabase` to `tvm.meta_schedule.database.MemoryDatabase` given it's actually useful - [x] Delegate class members' methods in `TuneContext`, for example, having `TuneContext.generste_design_space` from `TuneContext.space_generator.generste_design_space` Next PR: - Allow using a string `"default"` in `TuneContext` as well as `tune_relay/tir/te` to quickly specify a set of target-specific rules - Add `TuneContext.tune` to allow directly tuning without task scheduler. - Enhance detection of `ScheduleFn` in `TuneContext` to make it easier for users to quickly try out template-driven scheduling on TIR. Co-Authored-By: Kathryn (Jinqi) Chen <[email protected]>
1 parent 52d90da commit 56a186d

File tree

52 files changed

+1096
-886
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

52 files changed

+1096
-886
lines changed

include/tvm/meta_schedule/search_strategy.h

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -113,12 +113,10 @@ class SearchStrategyNode : public runtime::Object {
113113

114114
/*!
115115
* \brief Update the search strategy with measurement results.
116-
* \param context The tuning context.
117116
* \param measure_candidates The candidates to be measured.
118117
* \param results The measurement results from the runner.
119118
*/
120-
virtual void NotifyRunnerResults(const TuneContext& context,
121-
const Array<MeasureCandidate>& measure_candidates,
119+
virtual void NotifyRunnerResults(const Array<MeasureCandidate>& measure_candidates,
122120
const Array<RunnerResult>& results) = 0;
123121

124122
static constexpr const char* _type_key = "meta_schedule.SearchStrategy";
@@ -150,8 +148,8 @@ class PySearchStrategyNode : public SearchStrategyNode {
150148
* \brief The function type of `NotifyRunnerResults` method.
151149
* \param results The measurement results from the runner.
152150
*/
153-
using FNotifyRunnerResults = runtime::TypedPackedFunc<void(
154-
const TuneContext&, const Array<MeasureCandidate>&, const Array<RunnerResult>&)>;
151+
using FNotifyRunnerResults =
152+
runtime::TypedPackedFunc<void(const Array<MeasureCandidate>&, const Array<RunnerResult>&)>;
155153

156154
/*! \brief The packed function to the `InitializeWithTuneContext` method. */
157155
FInitializeWithTuneContext f_initialize_with_tune_context;
@@ -177,8 +175,7 @@ class PySearchStrategyNode : public SearchStrategyNode {
177175
const Optional<CostModel>& cost_model) final;
178176
void PostTuning() final;
179177
Optional<Array<MeasureCandidate>> GenerateMeasureCandidates() final;
180-
void NotifyRunnerResults(const TuneContext& context,
181-
const Array<MeasureCandidate>& measure_candidates,
178+
void NotifyRunnerResults(const Array<MeasureCandidate>& measure_candidates,
182179
const Array<RunnerResult>& results);
183180

184181
static constexpr const char* _type_key = "meta_schedule.PySearchStrategy";

include/tvm/meta_schedule/tune_context.h

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ namespace tvm {
4242
namespace meta_schedule {
4343

4444
class TaskSchedulerNode;
45+
class MeasureCallback;
4546

4647
/*! \brief The auto tuning context. */
4748
class TuneContextNode : public runtime::Object {
@@ -69,7 +70,10 @@ class TuneContextNode : public runtime::Object {
6970
/*! \brief The number of threads to be used. */
7071
int num_threads;
7172

72-
/*! \brief Whether the tuning task has been stopped or finished. */
73+
/*!
74+
* \brief Whether the tuning task has been stopped or finished.
75+
* TODO(@junrushao1994): move to TaskScheduler
76+
*/
7377
bool is_terminated;
7478
/*! \brief The measure candidates. */
7579
Optional<Array<MeasureCandidate>> measure_candidates;
@@ -87,18 +91,36 @@ class TuneContextNode : public runtime::Object {
8791
v->Visit("postprocs", &postprocs);
8892
v->Visit("mutator_probs", &mutator_probs);
8993
v->Visit("task_name", &task_name);
94+
// `logging_func` is not visited
9095
v->Visit("rand_state", &rand_state);
9196
v->Visit("num_threads", &num_threads);
9297
v->Visit("is_terminated", &is_terminated);
98+
v->Visit("measure_candidates", &measure_candidates);
9399
v->Visit("builder_results", &builder_results);
94100
v->Visit("runner_futures", &runner_futures);
95-
v->Visit("measure_candidates", &measure_candidates);
96-
// `logging_func` is not visited
97101
}
98102

99103
/*! \brief Initialize members that needs initialization with tune context. */
100104
void Initialize();
101-
105+
/*! \brief Set the measure candidates from the SearchStrategy */
106+
void _SetMeasureCandidates(const Array<MeasureCandidate>& candidates);
107+
/*!
108+
* \brief Send the measure candidates to builder.
109+
* \param builder The builder to send the candidates to.
110+
*/
111+
void _SendToBuilder(const Builder& builder);
112+
/*!
113+
* \brief Send the built measure candidates to runner.
114+
* \param runner The runner to send the candidates to.
115+
*/
116+
void _SendToRunner(const Runner& runner);
117+
/*!
118+
* \brief Join the running tasks.
119+
* \returns The results from the runner
120+
*/
121+
Array<RunnerResult> _Join();
122+
/*! \brief Set `measure_candidates`, `builder_results` and `runner_futures` to null. */
123+
void _ClearMeasureState();
102124
static constexpr const char* _type_key = "meta_schedule.TuneContext";
103125
TVM_DECLARE_FINAL_OBJECT_INFO(TuneContextNode, Object);
104126
};

python/tvm/meta_schedule/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@
2020
builder,
2121
cost_model,
2222
database,
23+
default_config,
2324
feature_extractor,
25+
measure_callback,
2426
mutator,
2527
postproc,
2628
runner,
@@ -32,5 +34,6 @@
3234
from .extracted_task import ExtractedTask
3335
from .relay_integration import extract_task_from_relay
3436
from .search_strategy import MeasureCandidate
35-
from .tune import TuneConfig, tune_relay, tune_te, tune_tir
37+
from .tune import TuneConfig, tune_extracted_tasks, tune_relay, tune_te, tune_tir
3638
from .tune_context import TuneContext
39+
from .utils import derived_object

python/tvm/meta_schedule/database/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,4 @@
2020
"""
2121
from .database import Database, PyDatabase, TuningRecord, Workload
2222
from .json_database import JSONDatabase
23+
from .memory_database import MemoryDatabase

python/tvm/meta_schedule/database/database.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# KIND, either express or implied. See the License for the
1515
# specific language governing permissions and limitations
1616
# under the License.
17-
"""Tuning record database"""
17+
"""TuningRecord database"""
1818
from typing import Any, Callable, List, Optional
1919

2020
from tvm._ffi import register_object
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""A database that stores TuningRecords in memory"""
18+
from typing import List
19+
20+
from ...ir import IRModule, structural_equal
21+
from ..utils import derived_object
22+
from .database import PyDatabase, TuningRecord, Workload
23+
24+
25+
@derived_object
26+
class MemoryDatabase(PyDatabase):
27+
"""An in-memory database based on python list for testing."""
28+
29+
def __init__(self):
30+
super().__init__()
31+
self.records = []
32+
self.workload_reg = []
33+
34+
def has_workload(self, mod: IRModule) -> bool:
35+
for workload in self.workload_reg:
36+
if structural_equal(workload.mod, mod):
37+
return True
38+
return False
39+
40+
def commit_tuning_record(self, record: TuningRecord) -> None:
41+
self.records.append(record)
42+
43+
def commit_workload(self, mod: IRModule) -> Workload:
44+
for workload in self.workload_reg:
45+
if structural_equal(workload.mod, mod):
46+
return workload
47+
workload = Workload(mod)
48+
self.workload_reg.append(workload)
49+
return workload
50+
51+
def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]:
52+
return list(
53+
filter(
54+
lambda x: x.workload == workload,
55+
sorted(self.records, key=lambda x: sum(x.run_secs) / len(x.run_secs)),
56+
)
57+
)[: int(top_k)]
58+
59+
def __len__(self) -> int:
60+
return len(self.records)
61+
62+
def print_results(self) -> None:
63+
print("\n".join([str(r) for r in self.records]))

0 commit comments

Comments
 (0)