Skip to content

Commit e8ec543

Browse files
Siyuan FengjunrushaozxybazhspectrometerHBHMasterJH5574
committed
[MetaSchedule] random compute location
Co-authored-by: Junru Shao <[email protected]> Co-authored-by: Xiyou Zhou <[email protected]> Co-authored-by: Bohan Hou <[email protected]> Co-authored-by: Ruihang Lai <[email protected]> Co-authored-by: Hongyi Jin <[email protected]> Co-authored-by: Wuwei Lin <[email protected]>
1 parent 7485413 commit e8ec543

File tree

19 files changed

+697
-0
lines changed

19 files changed

+697
-0
lines changed

include/tvm/tir/schedule/schedule.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,14 @@ class ScheduleNode : public runtime::Object {
210210
*/
211211
virtual Array<ExprRV> SamplePerfectTile(const LoopRV& loop_rv, int n, int max_innermost_factor,
212212
Optional<Array<Integer>> decision = NullOpt) = 0;
213+
/*!
214+
* \brief Sample a compute-at location of the given block
215+
* \param block_rv The block whose compute-at location is to be sampled
216+
* \param decision The sampling decision
217+
* \return The sampled loop where the input block is to be computed at
218+
*/
219+
virtual LoopRV SampleComputeLocation(const BlockRV& block_rv,
220+
Optional<Integer> decision = NullOpt) = 0;
213221

214222
/******** Schedule: Get blocks & loops ********/
215223
/*!

include/tvm/tir/stmt.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1361,6 +1361,13 @@ constexpr const char* script_parsing_detect_access = "tir.script_parsing_detect_
13611361
*/
13621362
constexpr const char* pragma_loop_partition_hint = "pragma_loop_partition_hint";
13631363

1364+
/*! \brief Mark the tiling structure of blocks that are applied by rule Multi-Level-Tiling */
1365+
constexpr const char* meta_schedule_tiling_structure = "meta_schedule.tiling_structure";
1366+
1367+
/*! \brief Mark the block whose producer needs to be applied by rule Random-Compute-Location */
1368+
constexpr const char* meta_schedule_random_compute_producer =
1369+
"meta_schedule.random_compute_producer";
1370+
13641371
/*!
13651372
* \brief Check if attr_key is a pragma key extension
13661373
* \param attr_key The attr key to be compared

python/tvm/meta_schedule/schedule_rule/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,6 @@
1616
Meta Schedule schedule rules are used for modification of
1717
blocks in a schedule. See also PostOrderApply.
1818
"""
19+
1920
from .schedule_rule import PyScheduleRule, ScheduleRule
21+
from .random_compute_location import RandomComputeLocation
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""Rule that randomly select a compute-at location for a free block"""
18+
from tvm._ffi import register_object
19+
20+
from .. import _ffi_api
21+
from .schedule_rule import ScheduleRule
22+
23+
24+
@register_object("meta_schedule.RandomComputeLocation")
25+
class RandomComputeLocation(ScheduleRule):
26+
"""A rule that randomly select a compute-at location for a free block"""
27+
28+
def __init__(self) -> None:
29+
self.__init_handle_by_constructor__(
30+
_ffi_api.ScheduleRuleRandomComputeLocation, # type: ignore # pylint: disable=no-member
31+
)
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
18+
from typing import List
19+
20+
from tvm.tir import Schedule
21+
from tvm.tir.schedule import Trace
22+
23+
24+
def check_trace(spaces: List[Schedule], expected: List[List[str]]):
25+
expected_traces = {"\n".join(t) for t in expected}
26+
actual_traces = set()
27+
for space in spaces:
28+
trace = Trace(space.trace.insts, {})
29+
trace = trace.simplified(remove_postproc=True)
30+
str_trace = "\n".join(str(trace).strip().splitlines())
31+
actual_traces.add(str_trace)
32+
assert str_trace in expected_traces, "\n" + str_trace
33+
assert len(expected_traces) == len(actual_traces)

python/tvm/tir/schedule/schedule.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,32 @@ def sample_perfect_tile(
369369
)
370370
)
371371

372+
@type_checked
373+
def sample_compute_location(
374+
self,
375+
block: BlockRV,
376+
decision: Optional[int] = None,
377+
) -> LoopRV:
378+
"""Sample a compute-at location of the given block
379+
380+
Parameters
381+
----------
382+
block : BlockRV
383+
The block whose compute-at location is to be sampled
384+
decision : Optional[int]
385+
The sampling decision
386+
387+
Returns
388+
-------
389+
result : LoopRV
390+
The sampled loop where the input block is to be computed at
391+
"""
392+
return _ffi_api.ScheduleSampleComputeLocation( # type: ignore # pylint: disable=no-member
393+
self,
394+
block,
395+
decision,
396+
)
397+
372398
########## Schedule: Get blocks & loops ##########
373399
@type_checked
374400
def get_block(
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
#include "../utils.h"
20+
21+
namespace tvm {
22+
namespace meta_schedule {
23+
class RandomComputeLocationNode : public ScheduleRuleNode {
24+
public:
25+
// Inherited from ScheduleRuleNode
26+
void InitializeWithTuneContext(const TuneContext& context) final {}
27+
28+
// Inherited from ScheduleRuleNode
29+
Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final {
30+
if (!CheckConditions(sch, block_rv)) {
31+
return {sch};
32+
}
33+
34+
// Step 1. If the producer of the input block needs a random compute-at location (specified by
35+
// the annotation), we colect the producer first, and transform the producer block later.
36+
// - The reason we collect the producer before transforming the input block is that, if the
37+
// decision of Sample-Compute-Location is "compute-inline" for the input block, we can no longer
38+
// access the input block. Hence we collect its producer ahead of time.
39+
// - Note that only single producer is allowed in this case.
40+
Array<tir::BlockRV> producers{nullptr};
41+
if (tir::HasAnn(sch->GetSRef(block_rv), tir::attr::meta_schedule_random_compute_producer,
42+
true)) {
43+
producers = sch->GetProducers(block_rv);
44+
sch->Unannotate(block_rv, tir::attr::meta_schedule_random_compute_producer);
45+
ICHECK_EQ(producers.size(), 1);
46+
}
47+
48+
// Step 2. Transform the input block.
49+
tir::Schedule res = RandomlyComputeAt(sch, block_rv);
50+
51+
// Step 3. Transform the producer block if compute-location sampling is needed.
52+
if (producers.defined()) {
53+
res = RandomlyComputeAt(res, producers[0]);
54+
}
55+
56+
return {res};
57+
}
58+
59+
private:
60+
bool CheckConditions(const tir::Schedule sch, const tir::BlockRV& block_rv) const {
61+
const tir::StmtSRef& block_sref = sch->GetSRef(block_rv);
62+
const tir::BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
63+
64+
// Cond 1. The block is not the root block.
65+
if (block_sref->parent == nullptr) {
66+
return false;
67+
}
68+
// Cond 2. The block should be the direct child block of the root block.
69+
if (GetScopeRoot(sch->state(), block_sref, //
70+
/*require_stage_pipeline=*/false, //
71+
/*require_subtree_compact_dataflow=*/false)
72+
->parent != nullptr) {
73+
return false;
74+
}
75+
// Cond 3 & 4. The block has at least one outer loop, and the outermost loop has only one child
76+
// block.
77+
Array<tir::StmtSRef> loop_srefs = tir::GetLoops(block_sref);
78+
if (loop_srefs.empty()) {
79+
return false;
80+
}
81+
if (tir::GetChildBlockSRefOnSRefTree(sch->state(), loop_srefs[0]).size() > 1) {
82+
return false;
83+
}
84+
// Cond 5. The block is not tiled. We check this condition by examine the block's annotation.
85+
if (tir::GetAnn<String>(block_sref, tir::attr::meta_schedule_tiling_structure).defined()) {
86+
return false;
87+
}
88+
// Cond 6. The block has at lease one consumer.
89+
if (tir::GetConsumers(sch->state(), sch->GetSRef(block_rv)).empty()) {
90+
return false;
91+
}
92+
return true;
93+
}
94+
95+
/*!
96+
* \brief Keep sampling a compute-at location for the input block until success.
97+
* \param sch The TIR schedule
98+
* \param block_rv The block whose compute-at location is to be sampled
99+
* \return The TIR schedule after transformation
100+
*/
101+
tir::Schedule RandomlyComputeAt(const tir::Schedule& sch, const tir::BlockRV& block_rv) {
102+
for (;;) {
103+
tir::LoopRV compute_at_loc = sch->SampleComputeLocation(block_rv);
104+
try {
105+
sch->ComputeAt(block_rv, compute_at_loc, true);
106+
} catch (const dmlc::Error& e) {
107+
// ComputeAt fails, cleanup the following before re-try:
108+
// 1) trace: instruction & decisions
109+
// 2) sym_tab
110+
sch->trace().value()->Pop();
111+
sch->RemoveRV(compute_at_loc);
112+
continue;
113+
}
114+
break;
115+
}
116+
return sch;
117+
}
118+
119+
public:
120+
void VisitAttrs(tvm::AttrVisitor* v) {}
121+
122+
static constexpr const char* _type_key = "meta_schedule.RandomComputeLocation";
123+
TVM_DECLARE_FINAL_OBJECT_INFO(RandomComputeLocationNode, ScheduleRuleNode);
124+
};
125+
126+
ScheduleRule ScheduleRule::RandomComputeLocation() {
127+
ObjectPtr<RandomComputeLocationNode> n = make_object<RandomComputeLocationNode>();
128+
return ScheduleRule(n);
129+
}
130+
131+
TVM_REGISTER_NODE_TYPE(RandomComputeLocationNode);
132+
TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleRandomComputeLocation")
133+
.set_body_typed(ScheduleRule::RandomComputeLocation);
134+
} // namespace meta_schedule
135+
} // namespace tvm

src/tir/schedule/analysis.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,31 @@ BlockRealize CheckGetSingleChildBlockRealizeOnSRefTree(const ScheduleState& self
266266
*/
267267
BlockRealize GetBlockRealize(const ScheduleState& self, const StmtSRef& block_sref);
268268

269+
/*!
270+
* \brief Get the IterVarType of the specific loop, according to the blocks it's bound to
271+
* \param loop_sref The loop to be checked
272+
* \return The IterVarType of the specific loop
273+
*/
274+
IterVarType GetLoopIterType(const StmtSRef& loop_sref);
275+
276+
/*!
277+
* \brief Get the lowest common ancestor of an array of blocks or loops on the sref tree
278+
* \param srefs The block srefs or loop srefs whose lowest common ancestor is to be queried
279+
* \return The lowest common ancestor of the input block srefs or loop srefs
280+
* \note The input array is required to have at least one sref
281+
*/
282+
StmtSRef GetSRefLowestCommonAncestor(const Array<StmtSRef>& srefs);
283+
284+
/*!
285+
* \brief Collect all the feasible compute-at locations of the input block
286+
* \param self The schedule state
287+
* \param block_sref The block whose compute-at locations are to be collected
288+
* \return All the feasible compute-at locations of the input block, given as an array of loop srefs
289+
* and an array of their indices among the outer loops of the input block
290+
*/
291+
std::pair<Array<StmtSRef>, std::vector<int>> CollectComputeLocation(const ScheduleState& self,
292+
const StmtSRef& block_sref);
293+
269294
/******** Producer-consumer relation ********/
270295

271296
/*!

0 commit comments

Comments
 (0)