Skip to content

Commit b8cfc67

Browse files
zxybazhjunrushao
authored andcommitted
[MetaSchedule] Add Profiler Support For Tuning Efficiency Optimization
1 parent 8341e33 commit b8cfc67

File tree

7 files changed

+398
-2
lines changed

7 files changed

+398
-2
lines changed
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
#ifndef TVM_META_SCHEDULE_PROFILER_H_
20+
#define TVM_META_SCHEDULE_PROFILER_H_
21+
22+
#include <tvm/ir/module.h>
23+
#include <tvm/node/reflection.h>
24+
#include <tvm/runtime/container/array.h>
25+
#include <tvm/runtime/container/optional.h>
26+
#include <tvm/runtime/container/string.h>
27+
#include <tvm/runtime/object.h>
28+
#include <tvm/runtime/packed_func.h>
29+
#include <tvm/target/target.h>
30+
31+
#include <utility>
32+
#include <vector>
33+
34+
namespace tvm {
35+
namespace meta_schedule {
36+
37+
struct ScopedTimer {
38+
std::function<void()> func;
39+
explicit ScopedTimer(std::function<void()> func) : func(func) {}
40+
~ScopedTimer() { func(); }
41+
};
42+
43+
/*!
44+
* \brief A profiler to count tuning time cost in different parts.
45+
*/
46+
class ProfilerNode : public runtime::Object {
47+
public:
48+
void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("stats", &stats); }
49+
50+
/*!
51+
* \brief Profile the time usage in the given scope in the given name.
52+
* \param name Name for the scope.
53+
* \return A scope timer for time profiling.
54+
*/
55+
static ScopedTimer TimeScope(String name);
56+
57+
/*!
58+
* \brief Get the profiling results.
59+
* \return The tuning profiling results as a dict.
60+
*/
61+
Map<String, FloatImm> Get() const { return stats; }
62+
63+
/*!
64+
* \brief Start the timer for a new context.
65+
* \param name Name of the context.
66+
*/
67+
void StartContextTimer(String name);
68+
69+
/*! \brief End the timer for the most recent context. */
70+
void EndContextTimer();
71+
72+
static constexpr const char* _type_key = "meta_schedule.Profiler";
73+
TVM_DECLARE_FINAL_OBJECT_INFO(ProfilerNode, runtime::Object);
74+
75+
protected:
76+
Map<String, FloatImm> stats;
77+
std::vector<std::pair<String, std::chrono::time_point<std::chrono::high_resolution_clock>>> stack;
78+
};
79+
80+
/*!
81+
* \brief Managed reference to ProfilerNode
82+
* \sa ProfilerNode
83+
*/
84+
class Profiler : public runtime::ObjectRef {
85+
public:
86+
Profiler();
87+
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Profiler, runtime::ObjectRef, ProfilerNode);
88+
89+
protected:
90+
friend class ProfilerInternal;
91+
92+
/*! \brief Entering the scope of the context manager */
93+
void EnterWithScope();
94+
/*! \brief Exiting the scope of the context manager */
95+
void ExitWithScope();
96+
};
97+
98+
struct ProfilerThreadLocalEntry {
99+
Optional<Profiler> ctx;
100+
};
101+
using ProfilerThreadLocalStore = dmlc::ThreadLocalStore<ProfilerThreadLocalEntry>;
102+
103+
} // namespace meta_schedule
104+
} // namespace tvm
105+
106+
#endif // TVM_META_SCHEDULE_PROFILER_H_

python/tvm/meta_schedule/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
search_strategy,
3131
space_generator,
3232
)
33+
from .profiler import Profiler
3334
from .apply_history_best import ApplyHistoryBest
3435
from .extracted_task import ExtractedTask
3536
from .relay_integration import extract_task_from_relay
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""A context manager that profiles tuning time cost for different parts."""
18+
from __future__ import annotations
19+
20+
import logging
21+
from typing import Dict
22+
23+
from tvm._ffi import register_object
24+
from tvm.runtime import Object
25+
26+
from . import _ffi_api
27+
28+
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
29+
30+
31+
class TimeItContext:
32+
"""The context to profile given scope."""
33+
34+
profiler: Profiler
35+
name: str
36+
37+
def __init__(self, profiler: "Profiler", name: str):
38+
self.profiler = profiler
39+
self.name = name
40+
41+
def __enter__(self):
42+
_ffi_api.ProfilerStartContextTimer(self.profiler, self.name) # type: ignore # pylint: disable=no-member
43+
return self
44+
45+
def __exit__(self, exctype, excinst, exctb):
46+
_ffi_api.ProfilerEndContextTimer(self.profiler) # type: ignore # pylint: disable=no-member
47+
48+
49+
@register_object("meta_schedule.Profiler")
50+
class Profiler(Object):
51+
"""A profiler to count tuning time cost in different parts."""
52+
53+
def __init__(self) -> None:
54+
self.__init_handle_by_constructor__(
55+
_ffi_api.Profiler, # type: ignore # pylint: disable=no-member
56+
)
57+
58+
def get(self) -> Dict[str, float]:
59+
"""Get the profiling results in minutes"""
60+
return _ffi_api.ProfilerGet(self) # type: ignore # pylint: disable=no-member
61+
62+
def timeit(self, name: str) -> TimeItContext:
63+
return TimeItContext(self, name)
64+
65+
def __enter__(self) -> "Profiler":
66+
"""Entering the scope of the context manager"""
67+
_ffi_api.ProfilerEnterScope(self) # type: ignore # pylint: disable=no-member
68+
return self
69+
70+
def __exit__(self, ptype, value, trace) -> None:
71+
"""Exiting the scope of the context manager"""
72+
_ffi_api.ProfilerExitScope(self) # type: ignore # pylint: disable=no-member

src/meta_schedule/profiler.cc

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
#include "./utils.h"
20+
21+
namespace tvm {
22+
namespace meta_schedule {
23+
24+
/**************** Context Manager ****************/
25+
26+
class ProfilerInternal {
27+
public:
28+
static void EnterScope(Profiler ctx) { ctx.EnterWithScope(); }
29+
static void ExitScope(Profiler ctx) { ctx.ExitWithScope(); }
30+
};
31+
32+
void Profiler::EnterWithScope() {
33+
Optional<Profiler>& ctx = ProfilerThreadLocalStore::Get()->ctx;
34+
CHECK(!ctx.defined()) << "ValueError: Nested Profiler context managers are not allowed";
35+
ctx = *this;
36+
}
37+
38+
void Profiler::ExitWithScope() {
39+
Optional<Profiler>& ctx = ProfilerThreadLocalStore::Get()->ctx;
40+
ICHECK(ctx.defined());
41+
ctx = NullOpt;
42+
}
43+
44+
/**************** Profiler ****************/
45+
46+
Profiler::Profiler() {
47+
ObjectPtr<ProfilerNode> n = make_object<ProfilerNode>();
48+
data_ = n;
49+
}
50+
51+
ScopedTimer ProfilerNode::TimeScope(String name) {
52+
return ScopedTimer([name, tick = std::chrono::high_resolution_clock::now()]() -> void {
53+
Optional<Profiler> profiler = ProfilerThreadLocalStore::Get()->ctx;
54+
if (profiler.defined()) {
55+
Map<String, FloatImm>& stats = profiler.value()->stats;
56+
double duration = std::chrono::duration_cast<std::chrono::nanoseconds>(
57+
std::chrono::high_resolution_clock::now() - tick)
58+
.count() /
59+
1e9 / 60;
60+
if (stats.find(name) != stats.end()) {
61+
stats.Set(name, FloatImm(DataType::Float(64), stats.at(name)->value + duration));
62+
} else {
63+
stats.Set(name, FloatImm(DataType::Float(64), duration));
64+
}
65+
}
66+
});
67+
}
68+
69+
void ProfilerNode::StartContextTimer(String name) {
70+
stack.push_back(std::make_pair(name, std::chrono::high_resolution_clock::now()));
71+
}
72+
73+
void ProfilerNode::EndContextTimer() {
74+
ICHECK(stack.size() > 0) << "There is no timer context running!";
75+
String name = stack.back().first;
76+
double duration = std::chrono::duration_cast<std::chrono::nanoseconds>(
77+
std::chrono::high_resolution_clock::now() - stack.back().second)
78+
.count() /
79+
1e9 / 60;
80+
if (stats.find(name) != stats.end()) {
81+
stats.Set(name, FloatImm(DataType::Float(64), stats.at(name)->value + duration));
82+
} else {
83+
stats.Set(name, FloatImm(DataType::Float(64), duration));
84+
}
85+
stack.pop_back();
86+
}
87+
88+
TVM_REGISTER_NODE_TYPE(ProfilerNode);
89+
TVM_REGISTER_GLOBAL("meta_schedule.Profiler").set_body_typed([]() -> Profiler {
90+
return Profiler();
91+
});
92+
TVM_REGISTER_GLOBAL("meta_schedule.ProfilerEnterScope")
93+
.set_body_typed(ProfilerInternal::EnterScope);
94+
TVM_REGISTER_GLOBAL("meta_schedule.ProfilerExitScope").set_body_typed(ProfilerInternal::ExitScope);
95+
TVM_REGISTER_GLOBAL("meta_schedule.ProfilerStartContextTimer")
96+
.set_body_method<Profiler>(&ProfilerNode::StartContextTimer);
97+
TVM_REGISTER_GLOBAL("meta_schedule.ProfilerEndContextTimer")
98+
.set_body_method<Profiler>(&ProfilerNode::EndContextTimer);
99+
TVM_REGISTER_GLOBAL("meta_schedule.ProfilerGet").set_body_method<Profiler>(&ProfilerNode::Get);
100+
101+
} // namespace meta_schedule
102+
} // namespace tvm

src/meta_schedule/utils.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include <tvm/meta_schedule/database.h>
2929
#include <tvm/meta_schedule/feature_extractor.h>
3030
#include <tvm/meta_schedule/measure_callback.h>
31+
#include <tvm/meta_schedule/profiler.h>
3132
#include <tvm/meta_schedule/runner.h>
3233
#include <tvm/meta_schedule/schedule_rule.h>
3334
#include <tvm/meta_schedule/search_strategy.h>

0 commit comments

Comments
 (0)