Skip to content

Commit 2b53437

Browse files
committed
TilingwithTensorIntrin works
1 parent 86baa31 commit 2b53437

File tree

1 file changed

+17
-0
lines changed

1 file changed

+17
-0
lines changed

src/meta_schedule/schedule_rule/multi_level_tiling.cc

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,7 @@ std::vector<State> SubRule(std::vector<State> states, FLambda sub_rule) {
364364
*/
365365
class MultiLevelTilingNode : public ScheduleRuleNode {
366366
public:
367+
inline std::vector<State> TileForVNNI(State state) const;
367368
// SubRule 1. add write cache
368369
inline std::vector<State> AddWriteReuse(State state) const;
369370
// SubRule 2. tile the loop nest
@@ -390,7 +391,9 @@ class MultiLevelTilingNode : public ScheduleRuleNode {
390391
}
391392
sch->Annotate(block_rv, tir::attr::meta_schedule_tiling_structure, structure);
392393

394+
LOG(INFO) << "Doing multi level tiling";
393395
std::vector<State> states{State(sch, block_rv)};
396+
states = SubRule(std::move(states), [&](State state) { return TileForVNNI(state); });
394397
states = SubRule(std::move(states), [&](State state) { return TileLoopNest(state); });
395398
states = SubRule(std::move(states), [&](State state) { return AddWriteReuse(state); });
396399
states = SubRule(std::move(states), [&](State state) { return AddReadReuse(state); });
@@ -444,6 +447,19 @@ class MultiLevelTilingNode : public ScheduleRuleNode {
444447
TVM_DECLARE_FINAL_OBJECT_INFO(MultiLevelTilingNode, ScheduleRuleNode);
445448
};
446449

450+
inline std::vector<State> MultiLevelTilingNode::TileForVNNI(State state) const {
451+
std::vector<State> result;
452+
BlockRV block_rv = state.block_rv;
453+
const std::string intrin_name = "dot_16x1x16_uint8_int8_int32_cascadelake";
454+
Optional<LoopRV> tiled_loop_rv = TilingwithTensorIntrin(state.sch, block_rv, intrin_name);
455+
ICHECK(tiled_loop_rv.defined());
456+
LOG(INFO) << "After TilingwithTensorIntrin" << state.sch->mod();
457+
state.block_rv = state.sch->Blockize(tiled_loop_rv.value());
458+
state.sch->Annotate(block_rv, tir::attr::meta_schedule_auto_tensorize, String(intrin_name));
459+
result.push_back(state);
460+
return result;
461+
}
462+
447463
inline std::vector<State> MultiLevelTilingNode::AddWriteReuse(State state) const {
448464
const ReuseConfig& config = this->reuse_write_;
449465
if (config.req == ReuseType::kNoReuse) {
@@ -503,6 +519,7 @@ inline std::vector<State> MultiLevelTilingNode::TileLoopNest(State state) const
503519
// Step 2. For each loop axis, tile it
504520
int64_t spatial_loop_product = 1;
505521
std::vector<Array<LoopRV>> tiles(s_indices_.size() + r_indices_.size());
522+
LOG(INFO) << "Tile loops: " << loops.size();
506523
for (int i = 0, n = loops.size(); i < n; ++i) {
507524
LoopRV loop = loops[i];
508525
const std::vector<int>* idx = nullptr;

0 commit comments

Comments
 (0)