Skip to content

Commit b35fc83

Browse files
zxybazhjunrushaospectrometerHBHMasterJH5574jinhongyii
authored
[M3c][MetaScheduler] Add More Measure Callbacks. (#9780)
* Add measure callbacks. Co-authored-by: Junru Shao <[email protected]> Co-authored-by: Bohan Hou <[email protected]> Co-authored-by: Ruihang Lai <[email protected]> Co-authored-by: Hongyi Jin <[email protected]> Co-authored-by: Wuwei Lin <[email protected]> Co-authored-by: Siyuan Feng <[email protected]> * Fix comments. Co-authored-by: Junru Shao <[email protected]> Co-authored-by: Bohan Hou <[email protected]> Co-authored-by: Ruihang Lai <[email protected]> Co-authored-by: Hongyi Jin <[email protected]> Co-authored-by: Wuwei Lin <[email protected]> Co-authored-by: Siyuan Feng <[email protected]>
1 parent d026d06 commit b35fc83

File tree

16 files changed

+1108
-0
lines changed

16 files changed

+1108
-0
lines changed
Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
#ifndef TVM_META_SCHEDULE_MEASURE_CALLBACK_H_
21+
#define TVM_META_SCHEDULE_MEASURE_CALLBACK_H_
22+
23+
#include <tvm/meta_schedule/builder.h>
24+
#include <tvm/meta_schedule/runner.h>
25+
#include <tvm/meta_schedule/search_strategy.h>
26+
#include <tvm/meta_schedule/tune_context.h>
27+
28+
namespace tvm {
29+
namespace meta_schedule {
30+
31+
class TaskScheduler;
32+
33+
/*! \brief Rules to apply after measure results is available. */
34+
class MeasureCallbackNode : public runtime::Object {
35+
public:
36+
/*! \brief Virtual destructor. */
37+
virtual ~MeasureCallbackNode() = default;
38+
39+
void VisitAttrs(tvm::AttrVisitor* v) {}
40+
41+
/*!
42+
* \brief Apply a measure callback rule with given arguments.
43+
* \param task_scheduler The task scheduler.
44+
* \param task_id The id of the task (tune context) to apply measure callbacks.
45+
* \param measure_candidates The measure candidates.
46+
* \param builder_results The builder results by building the measure candidates.
47+
* \param runner_results The runner results by running the built measure candidates.
48+
*/
49+
virtual void Apply(const TaskScheduler& task_scheduler, //
50+
int task_id, //
51+
const Array<MeasureCandidate>& measure_candidates, //
52+
const Array<BuilderResult>& builder_results, //
53+
const Array<RunnerResult>& runner_results) = 0;
54+
55+
static constexpr const char* _type_key = "meta_schedule.MeasureCallback";
56+
TVM_DECLARE_BASE_OBJECT_INFO(MeasureCallbackNode, Object);
57+
};
58+
59+
/*! \brief The measure callback with customized methods on the python-side. */
60+
class PyMeasureCallbackNode : public MeasureCallbackNode {
61+
public:
62+
/*!
63+
* \brief Apply a measure callback to the given schedule.
64+
* \param task_scheduler The task scheduler.
65+
* \param tasks The list of tune context to process.
66+
* \param measure_candidates The measure candidates.
67+
* \param builds The builder results by building the measure candidates.
68+
* \param results The runner results by running the built measure candidates.
69+
* \return Whether the measure callback was successfully applied.
70+
*/
71+
using FApply =
72+
runtime::TypedPackedFunc<void(const TaskScheduler& task_scheduler, //
73+
int task_id, //
74+
const Array<MeasureCandidate>& measure_candidates, //
75+
const Array<BuilderResult>& builds, //
76+
const Array<RunnerResult>& results)>;
77+
/*!
78+
* \brief Get the measure callback function as string with name.
79+
* \return The string of the measure callback function.
80+
*/
81+
using FAsString = runtime::TypedPackedFunc<String()>;
82+
83+
/*! \brief The packed function to the `Apply` function. */
84+
FApply f_apply;
85+
/*! \brief The packed function to the `AsString` function. */
86+
FAsString f_as_string;
87+
88+
void VisitAttrs(tvm::AttrVisitor* v) {
89+
// `f_apply` is not visited
90+
// `f_as_string` is not visited
91+
}
92+
93+
void Apply(const TaskScheduler& task_scheduler, //
94+
int task_id, //
95+
const Array<MeasureCandidate>& measure_candidates, //
96+
const Array<BuilderResult>& builds, //
97+
const Array<RunnerResult>& results) final {
98+
ICHECK(f_apply != nullptr) << "PyMeasureCallback's Apply method not implemented!";
99+
return this->f_apply(task_scheduler, task_id, measure_candidates, builds, results);
100+
}
101+
102+
static constexpr const char* _type_key = "meta_schedule.PyMeasureCallback";
103+
TVM_DECLARE_FINAL_OBJECT_INFO(PyMeasureCallbackNode, MeasureCallbackNode);
104+
};
105+
106+
/*!
107+
* \brief Managed reference to MeasureCallbackNode
108+
* \sa MeasureCallbackNode
109+
*/
110+
class MeasureCallback : public runtime::ObjectRef {
111+
public:
112+
/*!
113+
* \brief Create a measure callback that adds the measurement results into the database
114+
* \return The measure callback created.
115+
*/
116+
TVM_DLL static MeasureCallback AddToDatabase();
117+
/*!
118+
* \brief Create a measure callback that removes the build artifacts from the disk
119+
* \return The measure callback created.
120+
*/
121+
TVM_DLL static MeasureCallback RemoveBuildArtifact();
122+
/*!
123+
* \brief Create a measure callback that echos the statistics of the tuning process to the console
124+
* \return The measure callback created.
125+
*/
126+
TVM_DLL static MeasureCallback EchoStatistics();
127+
/*!
128+
* \brief Create a measure callback that updates the cost model with measurement result.
129+
* \return The measure callback created.
130+
*/
131+
TVM_DLL static MeasureCallback UpdateCostModel();
132+
/*!
133+
* \brief Create a measure callback with customized methods on the python-side.
134+
* \param f_apply The packed function of `Apply`.
135+
* \param f_as_string The packed function of `AsString`.
136+
* \return The measure callback created.
137+
*/
138+
TVM_DLL static MeasureCallback PyMeasureCallback(PyMeasureCallbackNode::FApply f_apply,
139+
PyMeasureCallbackNode::FAsString f_as_string);
140+
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(MeasureCallback, ObjectRef, MeasureCallbackNode);
141+
};
142+
143+
} // namespace meta_schedule
144+
} // namespace tvm
145+
146+
#endif // TVM_META_SCHEDULE_MEASURE_CALLBACK_H_

include/tvm/meta_schedule/task_scheduler.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,10 @@ class TaskSchedulerNode : public runtime::Object {
7373
Runner runner{nullptr};
7474
/*! \brief The database of the scheduler. */
7575
Database database{nullptr};
76+
/*! \brief The cost model of the scheduler. */
77+
Optional<CostModel> cost_model;
78+
/*! \brief The list of measure callbacks of the scheduler. */
79+
Array<MeasureCallback> measure_callbacks;
7680

7781
/*! \brief The default desctructor. */
7882
virtual ~TaskSchedulerNode() = default;
@@ -82,6 +86,8 @@ class TaskSchedulerNode : public runtime::Object {
8286
v->Visit("builder", &builder);
8387
v->Visit("runner", &runner);
8488
v->Visit("database", &database);
89+
v->Visit("cost_model", &cost_model);
90+
v->Visit("measure_callbacks", &measure_callbacks);
8591
}
8692

8793
/*! \brief Auto-tuning. */

include/tvm/meta_schedule/tune_context.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#define TVM_META_SCHEDULE_TUNE_CONTEXT_H_
2121

2222
#include <tvm/ir/module.h>
23+
#include <tvm/meta_schedule/schedule_rule.h>
2324
#include <tvm/meta_schedule/space_generator.h>
2425
#include <tvm/support/random_engine.h>
2526
#include <tvm/target/target.h>
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""
18+
The tvm.meta_schedule.measure_callback package.
19+
"""
20+
from .measure_callback import MeasureCallback, PyMeasureCallback
21+
from .add_to_database import AddToDatabase
22+
from .echo_statistics import EchoStatistics
23+
from .remove_build_artifact import RemoveBuildArtifact
24+
from .update_cost_model import UpdateCostModel
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""A callback that adds the measurement results into the database"""
18+
from tvm._ffi import register_object
19+
20+
from .. import _ffi_api
21+
from .measure_callback import MeasureCallback
22+
23+
24+
@register_object("meta_schedule.AddToDatabase")
25+
class AddToDatabase(MeasureCallback):
26+
def __init__(self) -> None:
27+
"""A callback that adds the measurement results into the database"""
28+
self.__init_handle_by_constructor__(
29+
_ffi_api.MeasureCallbackAddToDatabase, # type: ignore # pylint: disable=no-member
30+
)
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""A callback that echos the statistics of the tuning process to the console"""
18+
from tvm._ffi import register_object
19+
20+
from .. import _ffi_api
21+
from .measure_callback import MeasureCallback
22+
23+
24+
@register_object("meta_schedule.EchoStatistics")
25+
class EchoStatistics(MeasureCallback):
26+
def __init__(self) -> None:
27+
"""A callback that echos the statistics of the tuning process to the console"""
28+
self.__init_handle_by_constructor__(
29+
_ffi_api.MeasureCallbackEchoStatistics, # type: ignore # pylint: disable=no-member
30+
)
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""Meta Schedule MeasureCallback."""
18+
19+
from typing import List, TYPE_CHECKING
20+
21+
from tvm._ffi import register_object
22+
from tvm.runtime import Object
23+
24+
from .. import _ffi_api
25+
from ..builder import BuilderResult
26+
from ..runner import RunnerResult
27+
from ..search_strategy import MeasureCandidate
28+
from ..utils import _get_hex_address, check_override
29+
30+
if TYPE_CHECKING:
31+
from ..task_scheduler import TaskScheduler
32+
33+
34+
@register_object("meta_schedule.MeasureCallback")
35+
class MeasureCallback(Object):
36+
"""Rules to apply after measure results is available."""
37+
38+
def apply(
39+
self,
40+
task_scheduler: "TaskScheduler",
41+
task_id: int,
42+
measure_candidates: List[MeasureCandidate],
43+
builder_results: List[BuilderResult],
44+
runner_results: List[RunnerResult],
45+
) -> None:
46+
"""Apply a measure callback to the given schedule.
47+
48+
Parameters
49+
----------
50+
task_scheduler: TaskScheduler
51+
The task scheduler.
52+
task_id: int
53+
The task id.
54+
measure_candidates: List[MeasureCandidate]
55+
The measure candidates.
56+
builder_results: List[BuilderResult]
57+
The builder results by building the measure candidates.
58+
runner_results: List[RunnerResult]
59+
The runner results by running the built measure candidates.
60+
"""
61+
return _ffi_api.MeasureCallbackApply( # type: ignore # pylint: disable=no-member
62+
self,
63+
task_scheduler,
64+
task_id,
65+
measure_candidates,
66+
builder_results,
67+
runner_results,
68+
)
69+
70+
71+
@register_object("meta_schedule.PyMeasureCallback")
72+
class PyMeasureCallback(MeasureCallback):
73+
"""An abstract MeasureCallback with customized methods on the python-side."""
74+
75+
def __init__(self):
76+
"""Constructor."""
77+
78+
@check_override(self.__class__, MeasureCallback)
79+
def f_apply(
80+
task_scheduler: "TaskScheduler",
81+
task_id: int,
82+
measure_candidates: List[MeasureCandidate],
83+
builder_results: List[BuilderResult],
84+
runner_results: List[RunnerResult],
85+
) -> None:
86+
return self.apply(
87+
task_scheduler,
88+
task_id,
89+
measure_candidates,
90+
builder_results,
91+
runner_results,
92+
)
93+
94+
def f_as_string() -> str:
95+
return str(self)
96+
97+
self.__init_handle_by_constructor__(
98+
_ffi_api.MeasureCallbackPyMeasureCallback, # type: ignore # pylint: disable=no-member
99+
f_apply,
100+
f_as_string,
101+
)
102+
103+
def __str__(self) -> str:
104+
return f"PyMeasureCallback({_get_hex_address(self.handle)})"
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""A callback that removes the build artifacts from the disk"""
18+
from tvm._ffi import register_object
19+
20+
from .. import _ffi_api
21+
from .measure_callback import MeasureCallback
22+
23+
24+
@register_object("meta_schedule.RemoveBuildArtifact")
25+
class RemoveBuildArtifact(MeasureCallback):
26+
def __init__(self) -> None:
27+
"""A callback that removes the build artifacts from the disk"""
28+
self.__init_handle_by_constructor__(
29+
_ffi_api.MeasureCallbackRemoveBuildArtifact, # type: ignore # pylint: disable=no-member
30+
)

0 commit comments

Comments
 (0)