4141namespace  tvm  {
4242namespace  ansor  {
4343
44- TVM_REGISTER_OBJECT_TYPE (MetaTileRewritePolicyNode);
44+ TVM_REGISTER_NODE_TYPE (MetaTileRewritePolicyNode);
45+ TVM_REGISTER_OBJECT_TYPE (PreAddCustomRuleNode);
4546
4647//  All possible candidates for auto_unroll
4748const  std::vector<int > MetaTileRewritePolicyNode::auto_unroll_configs{0 , 16 , 64 , 512 , 1024 };
@@ -241,7 +242,7 @@ void MetaTileRewritePolicyNode::SearchOneRound(std::vector<State>* best_states,
241242
242243  //  Synthesize meta structure
243244  std::vector<State> meta_structures;
244-   SynthesizeMetaStructure (&meta_structures);
245+   GenerateMetaSketch (&meta_structures);
245246
246247  //  PrintAllStates(meta_structures);
247248  //  exit(0);
@@ -272,8 +273,8 @@ void MetaTileRewritePolicyNode::SearchOneRound(std::vector<State>* best_states,
272273  RandomSampleStates (init_population, &rand_gen_, num_random_states * 10 , random_states);
273274}
274275
275- //  The baseclass of derivation rules used in meta structure synthesis 
276- class  StructureSynthesisRule  {
276+ //  The baseclass of derivation rules used in meta sketch generation 
277+ class  SketchGenerationRule  {
277278 public: 
278279  enum  ConditionEnum {
279280    kPass , kApply , kApplyAndSkipRest 
@@ -345,7 +346,7 @@ static inline bool ShouldAlwaysBeInlined(
345346}
346347
347348//  The rule that inlines simple elementwise ops
348- class  RuleAlwaysInline  : public  StructureSynthesisRule  {
349+ class  RuleAlwaysInline  : public  SketchGenerationRule  {
349350 public: 
350351  ConditionEnum MeetCondition (const  MetaTileRewritePolicyNode* policy,
351352      const  State& state, int  stage_id) final  {
@@ -362,7 +363,7 @@ class RuleAlwaysInline : public StructureSynthesisRule {
362363};
363364
364365//  The rule that simply skip the current stage
365- class  RuleSkipStage  : public  StructureSynthesisRule  {
366+ class  RuleSkipStage  : public  SketchGenerationRule  {
366367 public: 
367368  ConditionEnum MeetCondition (const  MetaTileRewritePolicyNode* policy,
368369                              const  State& state, int  stage_id) final  {
@@ -387,7 +388,7 @@ class RuleSkipStage : public StructureSynthesisRule {
387388};
388389
389390//  The rule that performs multi-level tiling
390- class  RuleMultiLevelTiling  : public  StructureSynthesisRule  {
391+ class  RuleMultiLevelTiling  : public  SketchGenerationRule  {
391392 public: 
392393  ConditionEnum MeetCondition (const  MetaTileRewritePolicyNode* policy,
393394                              const  State& state, int  stage_id) final  {
@@ -413,7 +414,7 @@ class RuleMultiLevelTiling : public StructureSynthesisRule {
413414};
414415
415416//  The rule that performs multi-level tiling and fuses later consumers
416- class  RuleMultiLevelTilingWithFusion  : public  StructureSynthesisRule  {
417+ class  RuleMultiLevelTilingWithFusion  : public  SketchGenerationRule  {
417418 public: 
418419  ConditionEnum MeetCondition (const  MetaTileRewritePolicyNode* policy,
419420                              const  State& state, int  stage_id) final  {
@@ -482,7 +483,7 @@ class RuleMultiLevelTilingWithFusion : public StructureSynthesisRule {
482483};
483484
484485//  The rule that adds a cache write stage
485- class  RuleAddCacheWrite  : public  StructureSynthesisRule  {
486+ class  RuleAddCacheWrite  : public  SketchGenerationRule  {
486487 public: 
487488  ConditionEnum MeetCondition (const  MetaTileRewritePolicyNode* policy,
488489                              const  State& state, int  stage_id) final  {
@@ -515,7 +516,7 @@ class RuleAddCacheWrite : public StructureSynthesisRule {
515516//  The rule that adds a cache read stage
516517//  Mainly used for GPU cooperative fetching
517518//  Currently only support 1 to 1 match cache read
518- class  RuleAddCacheRead  : public  StructureSynthesisRule  {
519+ class  RuleAddCacheRead  : public  SketchGenerationRule  {
519520 public: 
520521  ConditionEnum MeetCondition (const  MetaTileRewritePolicyNode* policy,
521522                              const  State& state, int  stage_id) final  {
@@ -546,7 +547,7 @@ class RuleAddCacheRead : public StructureSynthesisRule {
546547};
547548
548549//  The rule that adds rfactor stage
549- class  RuleAddRfactor  : public  StructureSynthesisRule  {
550+ class  RuleAddRfactor  : public  SketchGenerationRule  {
550551 public: 
551552  ConditionEnum MeetCondition (const  MetaTileRewritePolicyNode* policy,
552553                              const  State& state, int  stage_id) final  {
@@ -610,7 +611,7 @@ class RuleAddRfactor : public StructureSynthesisRule {
610611  }
611612};
612613
613- void  MetaTileRewritePolicyNode::SynthesizeMetaStructure  (
614+ void  MetaTileRewritePolicyNode::GenerateMetaSketch  (
614615    std::vector<State>* out_states) {
615616  State init_state = cur_task_->compute_dag .GetInitState ();
616617  std::string cpu_multi_level_tiling_structure =
@@ -634,18 +635,22 @@ void MetaTileRewritePolicyNode::SynthesizeMetaStructure(
634635  static  RuleAddCacheWrite rule_add_cache_write_stage;
635636  static  RuleAddCacheRead rule_add_cache_read_stage;
636637  static  RuleAddRfactor rule_add_rfactor;
637-   //  We may apply and skip the rest when processing some rules,
638-   //  should take care of the rule vector order here
639-   static  std::vector<StructureSynthesisRule*> all_rules {
640-     &rule_always_inline, &rule_add_cache_write_stage,
641-     &rule_multi_level_tiling_with_fusion, &rule_multi_level_tiling,
642-     &rule_add_rfactor, &rule_skip_stage
643-   };
644-   if  (IS_GPU (cur_task_)) {
645-     //  Try cache read first before cache write
646-     all_rules.insert (all_rules.begin () + 1 , &rule_add_cache_read_stage);
638+   if  (sketch_rules.empty ()) {
639+     //  We may apply and skip the rest when processing some rules,
640+     //  should take care of the rule vector order here
641+     sketch_rules.push_back (&rule_always_inline);
642+     sketch_rules.push_back (&rule_add_cache_write_stage);
643+     sketch_rules.push_back (&rule_multi_level_tiling_with_fusion);
644+     sketch_rules.push_back (&rule_multi_level_tiling);
645+     sketch_rules.push_back (&rule_add_rfactor);
646+     sketch_rules.push_back (&rule_skip_stage);
647+     if  (IS_GPU (cur_task_)) {
648+       //  Try cache read first before cache write
649+       sketch_rules.insert (sketch_rules.begin () + 1 , &rule_add_cache_read_stage);
650+     }
651+     //  TODO(xian): Add a new rule to try combination of multi-level
652+     //  tiling + rfactor
647653  }
648-   //  TODO(xian): Add a new rule to try combination of multi-level tiling + rfactor
649654
650655  //  Derivation rule based synthesizer
651656  while  (!pnow->empty ()) {
@@ -661,15 +666,15 @@ void MetaTileRewritePolicyNode::SynthesizeMetaStructure(
661666      }
662667
663668      //  Try all derivation rules
664-       for  (const  auto & rule : all_rules ) {
669+       for  (const  auto & rule : sketch_rules ) {
665670        auto  rule_check = rule->MeetCondition (this , state, stage_id);
666-         if  (rule_check > StructureSynthesisRule ::ConditionEnum::kPass ) {
671+         if  (rule_check > SketchGenerationRule ::ConditionEnum::kPass ) {
667672          for  (const  auto & pair : rule->Apply (this , state, stage_id)) {
668673            cur_stage_id_map[pair.first ] = pair.second ;
669674            pnext->push_back (pair.first );
670675          }
671676          //  Skip the reset rules
672-           if  (rule_check == StructureSynthesisRule ::ConditionEnum::kApplyAndSkipRest ) {
677+           if  (rule_check == SketchGenerationRule ::ConditionEnum::kApplyAndSkipRest ) {
673678            break ;
674679          }
675680        }
@@ -1444,12 +1449,71 @@ void MetaTileRewritePolicyNode::EvolutionarySearch(
14441449                    << std::fixed << std::setprecision (2 ) << duration << std::endl;
14451450}
14461451
1452+ class  RuleCustomSketch  : public  SketchGenerationRule  {
1453+  public: 
1454+   RuleCustomSketch (PackedFunc meet_condition_func, PackedFunc apply_func) :
1455+       meet_condition_func_ (meet_condition_func), apply_func_(apply_func) {}
1456+ 
1457+   inline  ConditionEnum MeetCondition (const  MetaTileRewritePolicyNode* policy,
1458+                                      const  State& state, int  stage_id) final  {
1459+     auto  ret = meet_condition_func_ (
1460+         tvm::runtime::GetRef<MetaTileRewritePolicy>(policy), state, stage_id);
1461+     if  (ret.type_code () == 0 ) {
1462+       return  ConditionEnum (static_cast <int >(ret));
1463+     } else  {
1464+       return  kApplyAndSkipRest ;
1465+     }
1466+   }
1467+ 
1468+   inline  std::vector<std::pair<State, int > > Apply (
1469+       const  MetaTileRewritePolicyNode* policy,
1470+       const  State& state, int  stage_id) final  {
1471+     std::vector<std::pair<State, int > > ret;
1472+ 
1473+     Array<Array<ObjectRef>> apply_ret = apply_func_ (
1474+         tvm::runtime::GetRef<MetaTileRewritePolicy>(policy), state, stage_id);
1475+ 
1476+     for  (const  auto & item : apply_ret) {
1477+       CHECK_EQ (item.size (), 2 );
1478+       State state = Downcast<State>(item[0 ]);
1479+       auto  next = item[1 ].as <IntImmNode>();
1480+       ret.emplace_back (state, next->value );
1481+     }
1482+     return  ret;
1483+   }
1484+ 
1485+  private: 
1486+   PackedFunc meet_condition_func_;
1487+   PackedFunc apply_func_;
1488+ };
1489+ 
1490+ SearchCallback PreAddCustomRuleNode::make (PackedFunc meet_condition_func,
1491+                                           PackedFunc apply_func) {
1492+   auto  node = make_object<PreAddCustomRuleNode>();
1493+   node->meet_condition_func  = meet_condition_func;
1494+   node->apply_func  = apply_func;
1495+   return  SearchCallback (node);
1496+ }
1497+ 
1498+ void  PreAddCustomRuleNode::callback (SearchPolicyNode* policy) {
1499+   CHECK (policy->IsInstance <MetaTileRewritePolicyNode>());
1500+   auto  meta_policy = dynamic_cast <MetaTileRewritePolicyNode*>(policy);
1501+   meta_policy->sketch_rules .emplace_back (
1502+       new  RuleCustomSketch (meet_condition_func, apply_func));
1503+   StdCout (policy->verbose_ ) << " Custom sketch rule added."   << std::endl;
1504+ }
1505+ 
14471506TVM_REGISTER_GLOBAL (" ansor.MetaTileRewritePolicy"  )
14481507.set_body_typed([](CostModel program_cost_model,
14491508                   Map<String, ObjectRef> params,
14501509                   int  seed){
14511510  return  MetaTileRewritePolicyNode::make (program_cost_model, params, seed);
14521511});
14531512
1513+ TVM_REGISTER_GLOBAL (" ansor.PreAddCustomRule"  )
1514+ .set_body_typed([](PackedFunc meet_condition_func, PackedFunc apply_func) {
1515+   return  PreAddCustomRuleNode::make (meet_condition_func, apply_func);
1516+ });
1517+ 
14541518}  //  namespace ansor
14551519}  //  namespace tvm
0 commit comments