Skip to content

Commit 4284a47

Browse files
committed
introduce TileForIntrin
1 parent b87ef32 commit 4284a47

File tree

3 files changed

+15
-11
lines changed

3 files changed

+15
-11
lines changed

src/meta_schedule/schedule_rule/auto_tensorize.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,5 +87,13 @@ Optional<LoopRV> TilingwithTensorIntrin(const tir::Schedule& sch, const tir::Blo
8787
return reorder_suffix[0];
8888
}
8989

90+
tir::BlockRV TileForIntrin(tir::Schedule sch, tir::BlockRV block, const std::string& intrin_name) {
91+
Optional<tir::LoopRV> tiled_loop_rv = TilingwithTensorIntrin(sch, block, intrin_name);
92+
ICHECK(tiled_loop_rv.defined());
93+
tir::BlockRV outer_block = sch->Blockize(tiled_loop_rv.value());
94+
sch->Annotate(outer_block, tir::attr::meta_schedule_auto_tensorize, String(intrin_name));
95+
return outer_block;
96+
}
97+
9098
} // namespace meta_schedule
9199
} // namespace tvm

src/meta_schedule/schedule_rule/auto_tensorize.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ namespace meta_schedule {
2626

2727
Optional<tir::LoopRV> TilingwithTensorIntrin(const tir::Schedule& sch, const tir::BlockRV& block_rv,
2828
const String& intrin_name);
29+
30+
tir::BlockRV TileForIntrin(tir::Schedule sch, tir::BlockRV block, const std::string& intrin_name);
31+
2932
} // namespace meta_schedule
3033
} // namespace tvm
3134

src/meta_schedule/schedule_rule/multi_level_tiling_vnni.cc

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,20 +24,13 @@
2424
namespace tvm {
2525
namespace meta_schedule {
2626

27-
std::vector<State> TileForVNNI(State state) {
28-
const std::string intrin_name = "dot_16x4_vnni";
29-
Optional<tir::LoopRV> tiled_loop_rv =
30-
TilingwithTensorIntrin(state.sch, state.block_rv, intrin_name);
31-
ICHECK(tiled_loop_rv.defined());
32-
state.block_rv = state.sch->Blockize(tiled_loop_rv.value());
33-
state.sch->Annotate(state.block_rv, tir::attr::meta_schedule_auto_tensorize, String(intrin_name));
34-
return {state};
35-
}
36-
3727
class MultiLevelTilingVNNINode : public MultiLevelTilingNode {
3828
protected:
3929
virtual std::vector<State> ApplySubRules(std::vector<State> states) {
40-
states = SubRule(std::move(states), [&](State state) { return TileForVNNI(state); });
30+
states = SubRule(std::move(states), [&](State state) {
31+
state.block_rv = TileForIntrin(state.sch, state.block_rv, "dot_16x4_vnni");
32+
return std::vector<State>(1, state);
33+
});
4134
return MultiLevelTilingNode::ApplySubRules(states);
4235
}
4336

0 commit comments

Comments
 (0)