Skip to content

Commit 3a24e49

Browse files
jcf94merrymercy
authored andcommitted
Add python custom sketch rule (apache#21)
* Add custom sketch rule * Bug fix
1 parent cd0a516 commit 3a24e49

File tree

8 files changed

+230
-48
lines changed

8 files changed

+230
-48
lines changed

python/tvm/ansor/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@
2929

3030
# Shortcut
3131
from .compute_dag import ComputeDAG
32-
from .auto_schedule import SearchTask, MetaTileRewritePolicy, TuneOption, HardwareParams, PreLoadMeasuredStatesCallback
32+
from .auto_schedule import SearchTask, MetaTileRewritePolicy, TuneOption, HardwareParams, \
33+
PreLoadMeasuredStates, PreAddCustomRule
3334
from .auto_schedule import auto_schedule
3435
from .measure import MeasureInput, LocalBuilder, LocalRunner, RPCRunner, LocalRPCMeasureContext
3536
from .cost_model import RandomModel

python/tvm/ansor/auto_schedule.py

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -67,14 +67,16 @@ def __init__(self, dag, workload_key, target, target_host=None,
6767

6868
@tvm._ffi.register_object("ansor.SearchPolicy")
6969
class SearchPolicy(Object):
70+
""" The base search policy class
71+
"""
7072
def continue_search(self, task, num_measure, verbose, measurer):
7173
return _ffi_api.SearchPolicyContinueSearchOneRound(self, task, num_measure, verbose, measurer)
72-
74+
7375
def set_task(self, task):
74-
_ffi_api.SearchPolicySetTask(self, task);
76+
_ffi_api.SearchPolicySetTask(self, task)
7577

7678
def set_verbose(self, verbose):
77-
_ffi_api.SearchPolicySetVerbose(self, verbose);
79+
_ffi_api.SearchPolicySetVerbose(self, verbose)
7880

7981
def run_callbacks(self, callbacks):
8082
_ffi_api.SearchPolicyRunCallbacks(self, callbacks)
@@ -130,14 +132,39 @@ class SearchCallback(Object):
130132
pass
131133

132134

133-
@tvm._ffi.register_object("ansor.PreLoadMeasuredStatesCallback")
134-
class PreLoadMeasuredStatesCallback(SearchCallback):
135+
@tvm._ffi.register_object("ansor.PreLoadMeasuredStates")
136+
class PreLoadMeasuredStates(SearchCallback):
135137
""" A SearchCallback that used for search policy to load measured hash
136138
from the log file.
139+
140+
Parameters
141+
----------
142+
filename: Str
137143
"""
138144
def __init__(self, filename: str):
139145
self.__init_handle_by_constructor__(
140-
_ffi_api.PreLoadMeasuredStatesCallback, filename)
146+
_ffi_api.PreLoadMeasuredStates, filename)
147+
148+
149+
@tvm._ffi.register_object("ansor.PreAddCustomRule")
150+
class PreAddCustomRule(SearchCallback):
151+
"""
152+
A SearchCallback for MetaTileRewritePolicy that allowing users to add
153+
custom sketch rule.
154+
155+
Notice: This is an advanced feature, make sure you're clear how it
156+
works and this should only be used in MetaTileRewritePolicy.
157+
158+
Parameters
159+
----------
160+
meet_condition_func: Function
161+
A function with `(policy, state, stage_id) -> int`
162+
apply_func: Function
163+
A function with `(policy, state, stage_id) -> [[State, int], ...]`
164+
"""
165+
def __init__(self, meet_condition_func, apply_func):
166+
self.__init_handle_by_constructor__(
167+
_ffi_api.PreAddCustomRule, meet_condition_func, apply_func)
141168

142169

143170
@tvm._ffi.register_object("ansor.TuneOption")
@@ -159,8 +186,13 @@ class TuneOption(Object):
159186
runner: Runner
160187
Runner which runs the program and measure time costs
161188
measure_callbacks: List[MeasureCallback]
162-
Callback functions
189+
Callback functions called after each measure
190+
Candidates:
191+
- ansor.LogToFile
163192
pre_search_callbacks: List[SearchCallback]
193+
Callback functions called before the search process
194+
Candidates:
195+
- ansor.PreLoadMeasuredStates
164196
"""
165197
def __init__(self, n_trials=0, early_stopping=-1, num_measure_per_iter=64,
166198
verbose=1, builder='local', runner='local', measure_callbacks=None,

scripts/tune_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ def objective_func(costs):
157157
builder=builder,
158158
runner=runner,
159159
measure_callbacks=[ansor.LogToFile(log_file)],
160-
pre_search_callbacks=[ansor.PreLoadMeasuredStatesCallback(log_file)])
160+
pre_search_callbacks=[ansor.PreLoadMeasuredStates(log_file)])
161161

162162
if args.task_scheduler == 'no':
163163
# tune workloads one by one

src/ansor/search_policy/meta_tile_rewrite_policy.cc

Lines changed: 90 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@
4141
namespace tvm {
4242
namespace ansor {
4343

44-
TVM_REGISTER_OBJECT_TYPE(MetaTileRewritePolicyNode);
44+
TVM_REGISTER_NODE_TYPE(MetaTileRewritePolicyNode);
45+
TVM_REGISTER_OBJECT_TYPE(PreAddCustomRuleNode);
4546

4647
// All possible candidates for auto_unroll
4748
const std::vector<int> MetaTileRewritePolicyNode::auto_unroll_configs{0, 16, 64, 512, 1024};
@@ -241,7 +242,7 @@ void MetaTileRewritePolicyNode::SearchOneRound(std::vector<State>* best_states,
241242

242243
// Synthesize meta structure
243244
std::vector<State> meta_structures;
244-
SynthesizeMetaStructure(&meta_structures);
245+
GenerateMetaSketch(&meta_structures);
245246

246247
// PrintAllStates(meta_structures);
247248
// exit(0);
@@ -272,8 +273,8 @@ void MetaTileRewritePolicyNode::SearchOneRound(std::vector<State>* best_states,
272273
RandomSampleStates(init_population, &rand_gen_, num_random_states * 10, random_states);
273274
}
274275

275-
// The baseclass of derivation rules used in meta structure synthesis
276-
class StructureSynthesisRule {
276+
// The baseclass of derivation rules used in meta sketch generation
277+
class SketchGenerationRule {
277278
public:
278279
enum ConditionEnum {
279280
kPass, kApply, kApplyAndSkipRest
@@ -345,7 +346,7 @@ static inline bool ShouldAlwaysBeInlined(
345346
}
346347

347348
// The rule that inlines simple elementwise ops
348-
class RuleAlwaysInline : public StructureSynthesisRule {
349+
class RuleAlwaysInline : public SketchGenerationRule {
349350
public:
350351
ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy,
351352
const State& state, int stage_id) final {
@@ -362,7 +363,7 @@ class RuleAlwaysInline : public StructureSynthesisRule {
362363
};
363364

364365
// The rule that simply skip the current stage
365-
class RuleSkipStage : public StructureSynthesisRule {
366+
class RuleSkipStage : public SketchGenerationRule {
366367
public:
367368
ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy,
368369
const State& state, int stage_id) final {
@@ -387,7 +388,7 @@ class RuleSkipStage : public StructureSynthesisRule {
387388
};
388389

389390
// The rule that performs multi-level tiling
390-
class RuleMultiLevelTiling : public StructureSynthesisRule {
391+
class RuleMultiLevelTiling : public SketchGenerationRule {
391392
public:
392393
ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy,
393394
const State& state, int stage_id) final {
@@ -413,7 +414,7 @@ class RuleMultiLevelTiling : public StructureSynthesisRule {
413414
};
414415

415416
// The rule that performs multi-level tiling and fuses later consumers
416-
class RuleMultiLevelTilingWithFusion : public StructureSynthesisRule {
417+
class RuleMultiLevelTilingWithFusion : public SketchGenerationRule {
417418
public:
418419
ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy,
419420
const State& state, int stage_id) final {
@@ -482,7 +483,7 @@ class RuleMultiLevelTilingWithFusion : public StructureSynthesisRule {
482483
};
483484

484485
// The rule that adds a cache write stage
485-
class RuleAddCacheWrite : public StructureSynthesisRule {
486+
class RuleAddCacheWrite : public SketchGenerationRule {
486487
public:
487488
ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy,
488489
const State& state, int stage_id) final {
@@ -515,7 +516,7 @@ class RuleAddCacheWrite : public StructureSynthesisRule {
515516
// The rule that adds a cache read stage
516517
// Mainly used for GPU cooperative fetching
517518
// Currently only support 1 to 1 match cache read
518-
class RuleAddCacheRead : public StructureSynthesisRule {
519+
class RuleAddCacheRead : public SketchGenerationRule {
519520
public:
520521
ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy,
521522
const State& state, int stage_id) final {
@@ -546,7 +547,7 @@ class RuleAddCacheRead : public StructureSynthesisRule {
546547
};
547548

548549
// The rule that adds rfactor stage
549-
class RuleAddRfactor : public StructureSynthesisRule {
550+
class RuleAddRfactor : public SketchGenerationRule {
550551
public:
551552
ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy,
552553
const State& state, int stage_id) final {
@@ -610,7 +611,7 @@ class RuleAddRfactor : public StructureSynthesisRule {
610611
}
611612
};
612613

613-
void MetaTileRewritePolicyNode::SynthesizeMetaStructure(
614+
void MetaTileRewritePolicyNode::GenerateMetaSketch(
614615
std::vector<State>* out_states) {
615616
State init_state = cur_task_->compute_dag.GetInitState();
616617
std::string cpu_multi_level_tiling_structure =
@@ -634,18 +635,22 @@ void MetaTileRewritePolicyNode::SynthesizeMetaStructure(
634635
static RuleAddCacheWrite rule_add_cache_write_stage;
635636
static RuleAddCacheRead rule_add_cache_read_stage;
636637
static RuleAddRfactor rule_add_rfactor;
637-
// We may apply and skip the rest when processing some rules,
638-
// should take care of the rule vector order here
639-
static std::vector<StructureSynthesisRule*> all_rules {
640-
&rule_always_inline, &rule_add_cache_write_stage,
641-
&rule_multi_level_tiling_with_fusion, &rule_multi_level_tiling,
642-
&rule_add_rfactor, &rule_skip_stage
643-
};
644-
if (IS_GPU(cur_task_)) {
645-
// Try cache read first before cache write
646-
all_rules.insert(all_rules.begin() + 1, &rule_add_cache_read_stage);
638+
if (sketch_rules.empty()) {
639+
// We may apply and skip the rest when processing some rules,
640+
// should take care of the rule vector order here
641+
sketch_rules.push_back(&rule_always_inline);
642+
sketch_rules.push_back(&rule_add_cache_write_stage);
643+
sketch_rules.push_back(&rule_multi_level_tiling_with_fusion);
644+
sketch_rules.push_back(&rule_multi_level_tiling);
645+
sketch_rules.push_back(&rule_add_rfactor);
646+
sketch_rules.push_back(&rule_skip_stage);
647+
if (IS_GPU(cur_task_)) {
648+
// Try cache read first before cache write
649+
sketch_rules.insert(sketch_rules.begin() + 1, &rule_add_cache_read_stage);
650+
}
651+
// TODO(xian): Add a new rule to try combination of multi-level
652+
// tiling + rfactor
647653
}
648-
// TODO(xian): Add a new rule to try combination of multi-level tiling + rfactor
649654

650655
// Derivation rule based synthesizer
651656
while (!pnow->empty()) {
@@ -661,15 +666,15 @@ void MetaTileRewritePolicyNode::SynthesizeMetaStructure(
661666
}
662667

663668
// Try all derivation rules
664-
for (const auto& rule : all_rules) {
669+
for (const auto& rule : sketch_rules) {
665670
auto rule_check = rule->MeetCondition(this, state, stage_id);
666-
if (rule_check > StructureSynthesisRule::ConditionEnum::kPass) {
671+
if (rule_check > SketchGenerationRule::ConditionEnum::kPass) {
667672
for (const auto& pair : rule->Apply(this, state, stage_id)) {
668673
cur_stage_id_map[pair.first] = pair.second;
669674
pnext->push_back(pair.first);
670675
}
671676
// Skip the reset rules
672-
if (rule_check == StructureSynthesisRule::ConditionEnum::kApplyAndSkipRest) {
677+
if (rule_check == SketchGenerationRule::ConditionEnum::kApplyAndSkipRest) {
673678
break;
674679
}
675680
}
@@ -1444,12 +1449,71 @@ void MetaTileRewritePolicyNode::EvolutionarySearch(
14441449
<< std::fixed << std::setprecision(2) << duration << std::endl;
14451450
}
14461451

1452+
class RuleCustomSketch : public SketchGenerationRule {
1453+
public:
1454+
RuleCustomSketch(PackedFunc meet_condition_func, PackedFunc apply_func) :
1455+
meet_condition_func_(meet_condition_func), apply_func_(apply_func) {}
1456+
1457+
inline ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy,
1458+
const State& state, int stage_id) final {
1459+
auto ret = meet_condition_func_(
1460+
tvm::runtime::GetRef<MetaTileRewritePolicy>(policy), state, stage_id);
1461+
if (ret.type_code() == 0) {
1462+
return ConditionEnum(static_cast<int>(ret));
1463+
} else {
1464+
return kApplyAndSkipRest;
1465+
}
1466+
}
1467+
1468+
inline std::vector<std::pair<State, int> > Apply(
1469+
const MetaTileRewritePolicyNode* policy,
1470+
const State& state, int stage_id) final {
1471+
std::vector<std::pair<State, int> > ret;
1472+
1473+
Array<Array<ObjectRef>> apply_ret = apply_func_(
1474+
tvm::runtime::GetRef<MetaTileRewritePolicy>(policy), state, stage_id);
1475+
1476+
for (const auto& item : apply_ret) {
1477+
CHECK_EQ(item.size(), 2);
1478+
State state = Downcast<State>(item[0]);
1479+
auto next = item[1].as<IntImmNode>();
1480+
ret.emplace_back(state, next->value);
1481+
}
1482+
return ret;
1483+
}
1484+
1485+
private:
1486+
PackedFunc meet_condition_func_;
1487+
PackedFunc apply_func_;
1488+
};
1489+
1490+
SearchCallback PreAddCustomRuleNode::make(PackedFunc meet_condition_func,
1491+
PackedFunc apply_func) {
1492+
auto node = make_object<PreAddCustomRuleNode>();
1493+
node->meet_condition_func = meet_condition_func;
1494+
node->apply_func = apply_func;
1495+
return SearchCallback(node);
1496+
}
1497+
1498+
void PreAddCustomRuleNode::callback(SearchPolicyNode* policy) {
1499+
CHECK(policy->IsInstance<MetaTileRewritePolicyNode>());
1500+
auto meta_policy = dynamic_cast<MetaTileRewritePolicyNode*>(policy);
1501+
meta_policy->sketch_rules.emplace_back(
1502+
new RuleCustomSketch(meet_condition_func, apply_func));
1503+
StdCout(policy->verbose_) << "Custom sketch rule added." << std::endl;
1504+
}
1505+
14471506
TVM_REGISTER_GLOBAL("ansor.MetaTileRewritePolicy")
14481507
.set_body_typed([](CostModel program_cost_model,
14491508
Map<String, ObjectRef> params,
14501509
int seed){
14511510
return MetaTileRewritePolicyNode::make(program_cost_model, params, seed);
14521511
});
14531512

1513+
TVM_REGISTER_GLOBAL("ansor.PreAddCustomRule")
1514+
.set_body_typed([](PackedFunc meet_condition_func, PackedFunc apply_func) {
1515+
return PreAddCustomRuleNode::make(meet_condition_func, apply_func);
1516+
});
1517+
14541518
} // namespace ansor
14551519
} // namespace tvm

src/ansor/search_policy/meta_tile_rewrite_policy.h

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@
3838
namespace tvm {
3939
namespace ansor {
4040

41+
class SketchGenerationRule;
42+
4143
/*! Multi stage search policy */
4244
class MetaTileRewritePolicyNode: public SearchPolicyNode {
4345
public:
@@ -54,6 +56,7 @@ class MetaTileRewritePolicyNode: public SearchPolicyNode {
5456
* str gpu_multi_level_tiling_structure // The structure of multi-level tiling for GPU
5557
*/
5658
Map<String, ObjectRef> params;
59+
std::vector<SketchGenerationRule*> sketch_rules;
5760

5861
static SearchPolicy make(CostModel program_cost_model,
5962
Map<String, ObjectRef> params,
@@ -87,7 +90,7 @@ class MetaTileRewritePolicyNode: public SearchPolicyNode {
8790
int num_random_states, std::vector<State>* random_states);
8891

8992
// Synthesize meta tiling structure without tile size
90-
void SynthesizeMetaStructure(std::vector<State>* out_states);
93+
void GenerateMetaSketch(std::vector<State>* out_states);
9194

9295
// Sample init population
9396
void SampleInitPopulation(const std::vector<State>& meta_structures,
@@ -107,6 +110,22 @@ class MetaTileRewritePolicyNode: public SearchPolicyNode {
107110
// The throughputs of already measured states
108111
std::vector<float> measured_states_throughputs_;
109112
};
113+
TVM_DEFINE_MUTABLE_OBJECT_REF(MetaTileRewritePolicy, MetaTileRewritePolicyNode);
114+
115+
class PreAddCustomRuleNode : public SearchCallbackNode {
116+
public:
117+
// TODO(jcf94): Use tvm::runtime::TypedPackedFunc?
118+
PackedFunc meet_condition_func;
119+
PackedFunc apply_func;
120+
121+
static SearchCallback make(PackedFunc meet_condition_func,
122+
PackedFunc apply_func);
123+
124+
void callback(SearchPolicyNode* policy) final;
125+
126+
static constexpr const char *_type_key = "ansor.PreAddCustomRule";
127+
TVM_DECLARE_FINAL_OBJECT_INFO(PreAddCustomRuleNode, SearchCallbackNode);
128+
};
110129

111130
} // namespace ansor
112131
} // namespace tvm

0 commit comments

Comments
 (0)