@@ -87,6 +87,21 @@ void MultiLevelTilingNode::InitializeWithTuneContext(const TuneContext& context)
8787 TVM_PY_LOG (INFO, context->logger ) << " 'thread_warp_size' is not defined in the target" ;
8888 }
8989 }
90+ if (Optional<String> opt_sm = context->target .value ()->GetAttr <String>(" arch" )) {
91+ std::string sm = opt_sm.value ();
92+ if (support::StartsWith (sm, " sm_" )) {
93+ sm = sm.substr (3 );
94+ try {
95+ // only sm_80 or higher supports async memcopy
96+ if (std::stoi (sm) >= 80 ) {
97+ this ->stages .insert (this ->stages .end (), {4 , 5 });
98+ }
99+ } catch (const std::invalid_argument& e) {
100+ LOG (WARNING) << " ValueError: Unable to parse `target.arch`: " << sm
101+ << " . Details: " << e.what ();
102+ }
103+ }
104+ }
90105 logger = context->logger ;
91106}
92107
@@ -115,6 +130,9 @@ std::vector<State> MultiLevelTilingNode::ApplySubRules(std::vector<State> states
115130 states = SubRule (std::move (states), [&](State state) { return TileLoopNest (std::move (state)); });
116131 states = SubRule (std::move (states), [&](State state) { return AddWriteReuse (std::move (state)); });
117132 states = SubRule (std::move (states), [&](State state) { return AddReadReuse (std::move (state)); });
133+ states = SubRule (std::move (states), [&](State state) {
134+ return AddAsyncPipeline (std::move (state));
135+ });
118136 return states;
119137}
120138
@@ -280,6 +298,43 @@ std::vector<State> MultiLevelTilingNode::AddReadReuse(State state) const {
280298 return results;
281299}
282300
301+ std::vector<State> MultiLevelTilingNode::AddAsyncPipeline (State state) const {
302+ // For arch that does not support async pipeline, this->stages will be an empty vector
303+ if (r_indices_.size () < 1 || this ->stages .empty ()) {
304+ return {state};
305+ }
306+ // Current only support default config used by ScheduleRule::DefaultCUDA
307+ // @see src/meta_schedule/schedule_rule/schedule_rule.cc
308+ // check the reduce loop contains exactly 3 for loops
309+ // therefore it matches the notation array size in the following code
310+ tir::StmtSRef r_loop_sref = state->sch ->GetSRef (state->tiles [r_indices_[0 ]].back ());
311+ const tir::ForNode* r_for_loop = TVM_SREF_TO_FOR (r_loop_sref);
312+ Array<tir::Stmt> seq = Downcast<tir::SeqStmt>(r_for_loop->body )->seq ;
313+ if (seq.size () != 3 ) {
314+ return {state};
315+ }
316+ for (auto & stmt : seq) {
317+ if (!stmt.as <tir::ForNode>()) {
318+ return {state};
319+ }
320+ }
321+
322+ LoopRV r_loop_fused = state->sch ->Fuse (state->tiles [r_indices_[0 ]]);
323+ std::vector<State> ret;
324+ ret.push_back (state);
325+ for (int stage : this ->stages ) {
326+ State new_state = state->Copy ();
327+ new_state->sch ->Annotate (r_loop_fused, tir::attr::software_pipeline_stage,
328+ Array<Integer>{0 , 0 , stage - 2 });
329+ new_state->sch ->Annotate (r_loop_fused, tir::attr::software_pipeline_order,
330+ Array<Integer>{0 , 1 , 2 });
331+ new_state->sch ->Annotate (r_loop_fused, tir::attr::software_pipeline_async_stages,
332+ Array<Integer>{0 });
333+ ret.push_back (std::move (new_state));
334+ }
335+ return ret;
336+ }
337+
283338void MultiLevelTilingNode::AnnotateCooperativeFetching (Schedule* sch,
284339 const tir::BlockRV& block) const {
285340 // Filter out invalid vector lanes according to the data type.
0 commit comments