@@ -364,6 +364,7 @@ std::vector<State> SubRule(std::vector<State> states, FLambda sub_rule) {
364364 */
365365class MultiLevelTilingNode : public ScheduleRuleNode {
366366 public:
367+ inline std::vector<State> TileForVNNI (State state) const ;
367368 // SubRule 1. add write cache
368369 inline std::vector<State> AddWriteReuse (State state) const ;
369370 // SubRule 2. tile the loop nest
@@ -390,7 +391,9 @@ class MultiLevelTilingNode : public ScheduleRuleNode {
390391 }
391392 sch->Annotate (block_rv, tir::attr::meta_schedule_tiling_structure, structure);
392393
394+ LOG (INFO) << " Doing multi level tiling" ;
393395 std::vector<State> states{State (sch, block_rv)};
396+ states = SubRule (std::move (states), [&](State state) { return TileForVNNI (state); });
394397 states = SubRule (std::move (states), [&](State state) { return TileLoopNest (state); });
395398 states = SubRule (std::move (states), [&](State state) { return AddWriteReuse (state); });
396399 states = SubRule (std::move (states), [&](State state) { return AddReadReuse (state); });
@@ -444,6 +447,19 @@ class MultiLevelTilingNode : public ScheduleRuleNode {
444447 TVM_DECLARE_FINAL_OBJECT_INFO (MultiLevelTilingNode, ScheduleRuleNode);
445448};
446449
450+ inline std::vector<State> MultiLevelTilingNode::TileForVNNI (State state) const {
451+ std::vector<State> result;
452+ BlockRV block_rv = state.block_rv ;
453+ const std::string intrin_name = " dot_16x1x16_uint8_int8_int32_cascadelake" ;
454+ Optional<LoopRV> tiled_loop_rv = TilingwithTensorIntrin (state.sch , block_rv, intrin_name);
455+ ICHECK (tiled_loop_rv.defined ());
456+ LOG (INFO) << " After TilingwithTensorIntrin" << state.sch ->mod ();
457+ state.block_rv = state.sch ->Blockize (tiled_loop_rv.value ());
458+ state.sch ->Annotate (block_rv, tir::attr::meta_schedule_auto_tensorize, String (intrin_name));
459+ result.push_back (state);
460+ return result;
461+ }
462+
447463inline std::vector<State> MultiLevelTilingNode::AddWriteReuse (State state) const {
448464 const ReuseConfig& config = this ->reuse_write_ ;
449465 if (config.req == ReuseType::kNoReuse ) {
@@ -503,6 +519,7 @@ inline std::vector<State> MultiLevelTilingNode::TileLoopNest(State state) const
503519 // Step 2. For each loop axis, tile it
504520 int64_t spatial_loop_product = 1 ;
505521 std::vector<Array<LoopRV>> tiles (s_indices_.size () + r_indices_.size ());
522+ LOG (INFO) << " Tile loops: " << loops.size ();
506523 for (int i = 0 , n = loops.size (); i < n; ++i) {
507524 LoopRV loop = loops[i];
508525 const std::vector<int >* idx = nullptr ;
0 commit comments