Skip to content

Commit 40f5e25

Browse files
committed
support async pipeline in mlt
1 parent 202fead commit 40f5e25

File tree

2 files changed

+59
-0
lines changed

2 files changed

+59
-0
lines changed

src/meta_schedule/schedule_rule/multi_level_tiling.cc

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,21 @@ void MultiLevelTilingNode::InitializeWithTuneContext(const TuneContext& context)
8787
TVM_PY_LOG(INFO, context->logger) << "'thread_warp_size' is not defined in the target";
8888
}
8989
}
90+
if (Optional<String> opt_sm = context->target.value()->GetAttr<String>("arch")) {
91+
std::string sm = opt_sm.value();
92+
if (support::StartsWith(sm, "sm_")) {
93+
sm = sm.substr(3);
94+
try {
95+
// only sm_80 or higher supports async memcopy
96+
if (std::stoi(sm) >= 80) {
97+
this->stages.insert(this->stages.end(), {4, 5});
98+
}
99+
} catch (const std::invalid_argument& e) {
100+
LOG(WARNING) << "ValueError: Unable to parse `target.arch`: " << sm
101+
<< ". Details: " << e.what();
102+
}
103+
}
104+
}
90105
logger = context->logger;
91106
}
92107

@@ -115,6 +130,9 @@ std::vector<State> MultiLevelTilingNode::ApplySubRules(std::vector<State> states
115130
states = SubRule(std::move(states), [&](State state) { return TileLoopNest(std::move(state)); });
116131
states = SubRule(std::move(states), [&](State state) { return AddWriteReuse(std::move(state)); });
117132
states = SubRule(std::move(states), [&](State state) { return AddReadReuse(std::move(state)); });
133+
states = SubRule(std::move(states), [&](State state) {
134+
return AddAsyncPipeline(std::move(state));
135+
});
118136
return states;
119137
}
120138

@@ -280,6 +298,43 @@ std::vector<State> MultiLevelTilingNode::AddReadReuse(State state) const {
280298
return results;
281299
}
282300

301+
/*!
 * \brief SubRule 4. Produce software-pipelined variants of a tiled state.
 *
 * The input state is always returned, untouched, as the first candidate. If `stages` is
 * non-empty and the body of the innermost reduction tile is a sequence of exactly three
 * `for` loops, one extra candidate is produced per entry of `stages`: the reduction tiles
 * of a copy of the state are fused and the fused loop is annotated with the software
 * pipeline stage/order/async-stage attributes.
 *
 * \param state The state produced by the preceding sub-rules.
 * \return The unmodified input state followed by zero or more pipelined candidates.
 */
std::vector<State> MultiLevelTilingNode::AddAsyncPipeline(State state) const {
  // For arch that does not support async pipeline, this->stages will be an empty vector
  if (r_indices_.empty() || this->stages.empty()) {
    return {state};
  }
  // Nothing to pipeline if the first reduction axis has no tiles; also guards `.back()`.
  auto r_tiles = state->tiles[r_indices_[0]];
  if (r_tiles.empty()) {
    return {state};
  }
  // Currently only the default config used by ScheduleRule::DefaultCUDA is supported
  // @see src/meta_schedule/schedule_rule/schedule_rule.cc
  // Check that the reduce loop contains exactly 3 for loops, so that it matches the
  // size of the annotation arrays used below.
  tir::StmtSRef r_loop_sref = state->sch->GetSRef(r_tiles.back());
  const tir::ForNode* r_for_loop = TVM_SREF_TO_FOR(r_loop_sref);
  // Use a checked cast: a body that is not a SeqStmt is just another shape we do not
  // support, not a fatal error.
  const auto* seq_node = r_for_loop->body.as<tir::SeqStmtNode>();
  if (seq_node == nullptr || seq_node->seq.size() != 3) {
    return {state};
  }
  for (const tir::Stmt& stmt : seq_node->seq) {
    if (!stmt.as<tir::ForNode>()) {
      return {state};
    }
  }

  std::vector<State> ret;
  ret.reserve(this->stages.size() + 1);
  // Baseline candidate: must stay free of any pipeline-related transformation.
  ret.push_back(state);
  for (int stage : this->stages) {
    State new_state = state->Copy();
    // Fuse on the copy only, so that neither the baseline nor sibling candidates are
    // affected by this candidate's transformation.
    LoopRV r_loop_fused = new_state->sch->Fuse(new_state->tiles[r_indices_[0]]);
    new_state->sch->Annotate(r_loop_fused, tir::attr::software_pipeline_stage,
                             Array<Integer>{0, 0, stage - 2});
    new_state->sch->Annotate(r_loop_fused, tir::attr::software_pipeline_order,
                             Array<Integer>{0, 1, 2});
    new_state->sch->Annotate(r_loop_fused, tir::attr::software_pipeline_async_stages,
                             Array<Integer>{0});
    ret.push_back(std::move(new_state));
  }
  return ret;
}
337+
283338
void MultiLevelTilingNode::AnnotateCooperativeFetching(Schedule* sch,
284339
const tir::BlockRV& block) const {
285340
// Filter out invalid vector lanes according to the data type.

src/meta_schedule/schedule_rule/multi_level_tiling.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,8 @@ class MultiLevelTilingNode : public ScheduleRuleNode {
148148
std::vector<State> TileLoopNest(State state) const;
149149
// SubRule 3. add read cache
150150
std::vector<State> AddReadReuse(State state) const;
151+
// SubRule 4. add async pipeline
152+
std::vector<State> AddAsyncPipeline(State state) const;
151153

152154
// Initialize target-dependent parameters from the tune context; Inherited from ScheduleRuleNode
153155
void InitializeWithTuneContext(const TuneContext& context) final;
@@ -192,6 +194,8 @@ class MultiLevelTilingNode : public ScheduleRuleNode {
192194
int thread_warp_size_;
193195
/*! \brief The maximum number of threads per block */
194196
int max_threads_per_block_;
197+
/*! \brief All available async pipeline stages. */
198+
std::vector<int> stages;
195199
/*! \brief The logging function */
196200
PackedFunc logger;
197201
/*! \brief The function to overwrite the default condition for applying MultiLevelTiling. */

0 commit comments

Comments
 (0)