Skip to content

Commit ab8a106

Browse files
junrushaozxybazhspectrometerHBHMasterJH5574jinhongyii
authored
[MetaSchedule] Add Per-Store-Feature (#9860)
* [MetaSchedule] Add Per-Store-Feature Co-authored-by: Xiyou Zhou <[email protected]> Co-authored-by: Bohan Hou <[email protected]> Co-authored-by: Ruihang Lai <[email protected]> Co-authored-by: Hongyi Jin <[email protected]> Co-authored-by: Wuwei Lin <[email protected]> Co-authored-by: Siyuan Feng <[email protected]> * fix lint * fix lint * Update per_store_feature.py * address comments * fix lint Co-authored-by: Xiyou Zhou <[email protected]> Co-authored-by: Bohan Hou <[email protected]> Co-authored-by: Ruihang Lai <[email protected]> Co-authored-by: Hongyi Jin <[email protected]> Co-authored-by: Wuwei Lin <[email protected]> Co-authored-by: Siyuan Feng <[email protected]>
1 parent 38f0239 commit ab8a106

File tree

10 files changed

+3034
-16
lines changed

10 files changed

+3034
-16
lines changed

include/tvm/tir/stmt.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1251,7 +1251,7 @@ constexpr const char* extern_scope = "extern_scope";
12511251
* This can hint some code generator to create a new function for compute.
12521252
*/
12531253
constexpr const char* compute_scope = "compute_scope";
1254-
/*! \brief Mark storage alignement requirement of buffers */
1254+
/*! \brief Mark storage alignment requirement of buffers */
12551255
constexpr const char* storage_alignment = "storage_alignment";
12561256
/*! \brief Mark storage scope of realization */
12571257
constexpr const char* realize_scope = "realize_scope";
@@ -1263,6 +1263,10 @@ constexpr const char* device_type = "device_type";
12631263
constexpr const char* loop_scope = "loop_scope";
12641264
/*! \brief Mark of reduce scope */
12651265
constexpr const char* reduce_scope = "reduce_scope";
1266+
/*! \brief Pragma: auto-unroll, max_step */
1267+
constexpr const char* pragma_auto_unroll_max_step = "pragma_auto_unroll_max_step";
1268+
/*! \brief Pragma: unroll explicit */
1269+
constexpr const char* pragma_unroll_explicit = "pragma_unroll_explicit";
12661270
/*! \brief Mark region is guarded by the pragma extension */
12671271
constexpr const char* pragma_scope_prefix = "pragma_";
12681272
/*! \brief Import C source or file into the final code gen module */

python/tvm/meta_schedule/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,4 +23,6 @@
2323
from . import search_strategy
2424
from . import schedule_rule
2525
from . import integration
26+
from . import feature_extractor
2627
from .tune_context import TuneContext
28+
from .search_strategy import MeasureCandidate

python/tvm/meta_schedule/feature_extractor/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,5 @@
2020
measure candidates for use in cost model.
2121
"""
2222
from .feature_extractor import FeatureExtractor, PyFeatureExtractor
23+
from .per_store_feature import PerStoreFeature
2324
from .random_feature_extractor import RandomFeatureExtractor
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""We extract one feature vector per BufferStoreNode statement in a TIR Stmt,
18+
so we call this feature as "per-store" feature.
19+
"""
20+
from tvm._ffi import register_object
21+
22+
from .. import _ffi_api
23+
from .feature_extractor import FeatureExtractor
24+
25+
26+
@register_object("meta_schedule.PerStoreFeature")
27+
class PerStoreFeature(FeatureExtractor):
28+
"""PerStoreFeature extracts one feature vector per BufferStoreNode
29+
30+
Parameters
31+
----------
32+
buffers_per_store : int
33+
The number of buffers in each BufferStore; Pad or truncate if necessary.
34+
arith_intensity_curve_num_samples : int
35+
The number of samples used in the arithmetic intensity curve.
36+
cache_line_bytes : int
37+
The number of bytes in a cache line.
38+
"""
39+
40+
buffers_per_store: int
41+
"""The number of buffers in each BufferStore; Pad or truncate if necessary."""
42+
arith_intensity_curve_num_samples: int # pylint: disable=invalid-name
43+
"""The number of samples used in the arithmetic intensity curve."""
44+
cache_line_bytes: int
45+
"""The number of bytes in a cache line."""
46+
feature_vector_length: int
47+
"""Length of the feature vector."""
48+
49+
def __init__(
50+
self,
51+
buffers_per_store: int = 5,
52+
arith_intensity_curve_num_samples: int = 10,
53+
cache_line_bytes: int = 64,
54+
):
55+
self.__init_handle_by_constructor__(
56+
_ffi_api.FeatureExtractorPerStoreFeature, # type: ignore # pylint: disable=no-member
57+
buffers_per_store,
58+
arith_intensity_curve_num_samples,
59+
cache_line_bytes,
60+
)

0 commit comments

Comments
 (0)