
Commit 17b0b50

cblmemo authored and yongwww committed
[MetaSchedule] Introduce Async Pipeline in MultiLevelTiling (apache#14009)
This PR introduces an async pipeline into TVM's current MultiLevelTiling rules. It is based on apache#13966 (already merged), because some conv2d workloads use `tir.if_then_else` to pad the input to the correct size, and this PR applies async copy to such copy statements.

1. Add a subrule in `src/meta_schedule/schedule_rule/multi_level_tiling.h/.cc` that annotates async copy for MLT on supported architectures (sm_80 and above).

On CUDA cores, this PR brings a performance boost of around 1 TFLOP/s in most conv2d test cases and 1-2 TFLOP/s in most GEMM test cases. All generated code, scripts, and traces are available at https://github.com/Rainy-Memory/tvm-async-rule-benchmark. Currently tested on commit `afbfb7aa7e43732cb716f8e443df696110be6afc` with a conv2d NHWC workload on an RTX 3080 GPU.

**Notice: given the stochastic nature of evolutionary search, performance might become worse if this PR is enabled.**

Workload: Conv2d NHWC

|Shape|Mainline TVM (GFLOP/s)|Mainline TVM with Async (GFLOP/s)|Performance Boost|
|-|-|-|-|
|N=1_H=224_W=224_C=3_K=64_R=7_S=7_STR=2_PAD=3_DIL=1|13838.05|14687.89|6.14%|
|N=1_H=56_W=56_C=64_K=64_R=1_S=1_STR=1_PAD=0_DIL=1|5398.31|5613.89|3.99%|
|N=1_H=56_W=56_C=64_K=64_R=3_S=3_STR=1_PAD=1_DIL=1|11652.97|13157.88|12.91%|
|N=1_H=56_W=56_C=64_K=256_R=1_S=1_STR=1_PAD=0_DIL=1|10638.83|11674.68|9.74%|
|N=1_H=56_W=56_C=256_K=64_R=1_S=1_STR=1_PAD=0_DIL=1|8692.33|9469.26|8.94%|
|N=1_H=56_W=56_C=256_K=128_R=1_S=1_STR=2_PAD=0_DIL=1|4685.77|5698.20|21.61%|
|N=1_H=28_W=28_C=128_K=128_R=3_S=3_STR=1_PAD=1_DIL=1|9872.79|10404.60|5.39%|
|N=1_H=28_W=28_C=128_K=512_R=1_S=1_STR=1_PAD=0_DIL=1|9974.28|10073.32|0.99%|
|N=1_H=28_W=28_C=512_K=128_R=1_S=1_STR=1_PAD=0_DIL=1|7075.87|8564.57|21.04%|
|N=1_H=28_W=28_C=512_K=256_R=1_S=1_STR=2_PAD=0_DIL=1|3648.33|4021.92|10.24%|
|N=1_H=14_W=14_C=256_K=256_R=3_S=3_STR=1_PAD=1_DIL=1|8192.95|9160.18|11.81%|
|N=1_H=14_W=14_C=256_K=1024_R=1_S=1_STR=1_PAD=0_DIL=1|8008.87|9362.83|16.91%|
|N=1_H=14_W=14_C=1024_K=256_R=1_S=1_STR=1_PAD=0_DIL=1|5210.06|6051.21|16.14%|
|N=1_H=14_W=14_C=1024_K=512_R=1_S=1_STR=2_PAD=0_DIL=1|2550.79|3587.90|40.66%|
|N=1_H=7_W=7_C=512_K=512_R=3_S=3_STR=1_PAD=1_DIL=1|4350.63|5432.79|24.87%|
|N=1_H=7_W=7_C=512_K=2048_R=1_S=1_STR=1_PAD=0_DIL=1|6672.07|7663.73|14.86%|
|N=1_H=7_W=7_C=2048_K=512_R=1_S=1_STR=1_PAD=0_DIL=1|3142.56|4297.99|36.77%|

Workload: GEMM NN

|Shape|Mainline TVM (GFLOP/s)|Mainline TVM with Async (GFLOP/s)|Performance Boost|
|-|-|-|-|
|M=512_N=256_K=640|8678.46|10607.37|22.23%|
|M=512_N=384_K=256|8109.13|10290.72|26.90%|
|M=512_N=512_K=512|11419.83|14000.86|22.60%|
|M=512_N=3072_K=768|19709.39|18351.61|-6.89%|
|M=512_N=768_K=3072|12844.59|13730.88|6.90%|
|M=896_N=896_K=896|16149.91|16131.39|-0.11%|
|M=1024_N=1024_K=1024|18842.11|19662.80|4.36%|
|M=1152_N=1152_K=1152|15386.79|16736.10|8.77%|
|M=1536_N=1536_K=1536|18522.67|18872.06|1.89%|
|M=2048_N=2048_K=2048|19515.42|18874.85|-3.28%|
|M=3072_N=3072_K=3072|19233.90|19291.42|0.30%|
|M=4096_N=4096_K=4096|17122.17|19259.01|12.48%|
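As a quick usage illustration (a hedged sketch, not part of this commit): the design space with the new subrule active can be generated with the `generate_design_space` test helper used in the updated tests below. The GEMM workload here is illustrative.

```python
# Hedged sketch: exercising the new AddAsyncPipeline subrule via the test helper.
import tvm
import tvm.meta_schedule as ms
from tvm import te
from tvm.meta_schedule.testing.space_generation import generate_design_space
from tvm.target import Target

def matmul(m=512, n=512, k=512):
    # A simple GEMM workload, similar in spirit to the benchmarked shapes above.
    A = te.placeholder((m, k), name="A")
    B = te.placeholder((k, n), name="B")
    r = te.reduce_axis((0, k), name="r")
    C = te.compute((m, n), lambda i, j: te.sum(A[i, r] * B[r, j], axis=r), name="C")
    return te.create_prim_func([A, B, C])

spaces = generate_design_space(
    kind="cuda",
    mod=tvm.IRModule({"main": matmul()}),
    target=Target("nvidia/geforce-rtx-3080"),  # sm_86 >= sm_80, so async is enabled
    types=ms.schedule_rule.MultiLevelTiling,
)
# Expect the plain sketch plus stage-4 and stage-5 software-pipelined variants.
```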
1 parent 083d4dd commit 17b0b50

File tree

6 files changed (+405, -5 lines)


src/meta_schedule/schedule_rule/multi_level_tiling.cc

Lines changed: 56 additions & 0 deletions
```diff
@@ -87,6 +87,23 @@ void MultiLevelTilingNode::InitializeWithTuneContext(const TuneContext& context)
       TVM_PY_LOG(INFO, context->logger) << "'thread_warp_size' is not defined in the target";
     }
   }
+  if (Optional<String> opt_sm = context->target.value()->GetAttr<String>("arch")) {
+    std::string sm = opt_sm.value();
+    if (support::StartsWith(sm, "sm_")) {
+      sm = sm.substr(3);
+      try {
+        // Only sm_80 or higher supports async memcpy.
+        if (std::stoi(sm) >= 80) {
+          // Only stages 4 and 5 are tested. Any integer greater than 2 is
+          // theoretically feasible, but good performance is not guaranteed.
+          this->stages.insert(this->stages.end(), {4, 5});
+        }
+      } catch (const std::invalid_argument& e) {
+        LOG(WARNING) << "ValueError: Unable to parse `target.arch`: " << sm
+                     << ". Details: " << e.what();
+      }
+    }
+  }
   logger = context->logger;
 }
```
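For readers following along, here is a minimal Python sketch of the same architecture gate (illustrative only; `async_stages_for` is a hypothetical helper, not an API added by this commit):

```python
# Minimal sketch of the sm_80 gate above; `async_stages_for` is hypothetical.
from tvm.target import Target

def async_stages_for(target: Target):
    arch = str(target.attrs["arch"]) if "arch" in target.attrs else ""
    if arch.startswith("sm_"):
        try:
            # Async memcpy is only available on sm_80 (Ampere) and newer.
            if int(arch[3:]) >= 80:
                return [4, 5]  # the two tested pipeline depths
        except ValueError:
            pass  # unparsable arch string: leave async disabled
    return []

print(async_stages_for(Target("nvidia/geforce-rtx-3080")))  # [4, 5]
print(async_stages_for(Target("nvidia/geforce-rtx-2080")))  # []
```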

```diff
@@ -115,6 +132,8 @@ std::vector<State> MultiLevelTilingNode::ApplySubRules(std::vector<State> states
   states = SubRule(std::move(states), [&](State state) { return TileLoopNest(std::move(state)); });
   states = SubRule(std::move(states), [&](State state) { return AddWriteReuse(std::move(state)); });
   states = SubRule(std::move(states), [&](State state) { return AddReadReuse(std::move(state)); });
+  states =
+      SubRule(std::move(states), [&](State state) { return AddAsyncPipeline(std::move(state)); });
   return states;
 }
```
```diff
@@ -280,6 +299,43 @@ std::vector<State> MultiLevelTilingNode::AddReadReuse(State state) const {
   return results;
 }
 
+std::vector<State> MultiLevelTilingNode::AddAsyncPipeline(State state) const {
+  // For archs that do not support the async pipeline, this->stages is an empty vector.
+  if (r_indices_.size() < 1 || this->stages.empty()) {
+    return {state};
+  }
+  // Currently only the default config used by ScheduleRule::DefaultCUDA is supported.
+  // @see src/meta_schedule/schedule_rule/schedule_rule.cc
+  // Check that the reduction loop body contains exactly 3 for loops, so that it
+  // matches the annotation array sizes below.
+  tir::StmtSRef r_loop_sref = state->sch->GetSRef(state->tiles[r_indices_[0]].back());
+  const tir::ForNode* r_for_loop = TVM_SREF_TO_FOR(r_loop_sref);
+  Array<tir::Stmt> seq = Downcast<tir::SeqStmt>(r_for_loop->body)->seq;
+  if (seq.size() != 3) {
+    return {state};
+  }
+  for (auto& stmt : seq) {
+    if (!stmt.as<tir::ForNode>()) {
+      return {state};
+    }
+  }
+
+  std::vector<State> ret;
+  ret.push_back(state);
+  for (int stage : this->stages) {
+    State new_state = state->Copy();
+    LoopRV r_loop_fused = new_state->sch->Fuse(new_state->tiles[r_indices_[0]]);
+    new_state->sch->Annotate(r_loop_fused, tir::attr::software_pipeline_stage,
+                             Array<Integer>{0, 0, stage - 2});
+    new_state->sch->Annotate(r_loop_fused, tir::attr::software_pipeline_order,
+                             Array<Integer>{0, 1, 2});
+    new_state->sch->Annotate(r_loop_fused, tir::attr::software_pipeline_async_stages,
+                             Array<Integer>{0});
+    ret.push_back(std::move(new_state));
+  }
+  return ret;
+}
+
 void MultiLevelTilingNode::AnnotateCooperativeFetching(Schedule* sch,
                                                        const tir::BlockRV& block) const {
   // Filter out invalid vector lanes according to the data type.
```
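In TIR schedule terms, each new state produced by this subrule corresponds to annotations like the following (a hedged Python mirror of the C++ above; `sch`, `r_loops`, and `stage` are placeholders for the state's schedule, its reduction tile loops, and the chosen pipeline depth):

```python
# Hedged Python mirror of AddAsyncPipeline; arguments are placeholders.
from tvm import tir

def add_async_pipeline(sch: tir.Schedule, r_loops, stage: int):
    fused = sch.fuse(*r_loops)
    # The reduction body holds three statements: two shared-memory copies and
    # the compute block. Copies run in stage 0, compute in stage `stage - 2`.
    sch.annotate(fused, "software_pipeline_stage", [0, 0, stage - 2])
    sch.annotate(fused, "software_pipeline_order", [0, 1, 2])
    # Marking stage 0 asynchronous lets the copies lower to async copy (cp.async).
    sch.annotate(fused, "software_pipeline_async_stages", [0])
```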

src/meta_schedule/schedule_rule/multi_level_tiling.h

Lines changed: 4 additions & 0 deletions
```diff
@@ -148,6 +148,8 @@ class MultiLevelTilingNode : public ScheduleRuleNode {
   std::vector<State> TileLoopNest(State state) const;
   // SubRule 3. add read cache
   std::vector<State> AddReadReuse(State state) const;
+  // SubRule 4. add async pipeline
+  std::vector<State> AddAsyncPipeline(State state) const;
 
   // Do nothing; Inherited from ScheduleRuleNode
   void InitializeWithTuneContext(const TuneContext& context) final;
@@ -192,6 +194,8 @@ class MultiLevelTilingNode : public ScheduleRuleNode {
   int thread_warp_size_;
   /*! \brief The maximum number of threads per block */
   int max_threads_per_block_;
+  /*! \brief All available async pipeline stages. */
+  std::vector<int> stages;
   /*! \brief The logging function */
   PackedFunc logger;
   /*! \brief The function to overwrite the default condition for applying MultiLevelTiling. */
```

tests/python/unittest/test_meta_schedule_schedule_rule_mlt.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -365,7 +365,7 @@ def cuda_matmul_0(
     actual = generate_design_space(
         kind="cuda",
         mod=mod,
-        target=Target("nvidia/geforce-rtx-3080"),
+        target=Target("nvidia/geforce-rtx-2080"),  # disable async trace using sm75
         types=ms.schedule_rule.MultiLevelTiling,
     )
     check_sketches(
@@ -483,7 +483,7 @@ def cuda_matmul_relu_0(
     actual = generate_design_space(
         kind="cuda",
         mod=mod,
-        target=Target("nvidia/geforce-rtx-3080"),
+        target=Target("nvidia/geforce-rtx-2080"),  # disable async trace using sm75
         types=ms.schedule_rule.MultiLevelTiling,
     )
     check_sketches(
@@ -723,7 +723,7 @@ def cache_read_specify_consumer_0(
     space = generate_design_space(
         kind="cuda",
         mod=mod,
-        target=Target("nvidia/geforce-rtx-3080"),
+        target=Target("nvidia/geforce-rtx-2080"),  # disable async trace using sm75
         types=ms.schedule_rule.MultiLevelTiling,
     )
     check_sketches(
```
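The RTX 2080 swap works because its target tag carries `arch=sm_75`, below the sm_80 gate, so `stages` stays empty and the new subrule leaves these design spaces unchanged. A quick hedged check:

```python
# Hedged check: why an RTX 2080 target keeps these traces async-free.
from tvm.target import Target

print(Target("nvidia/geforce-rtx-2080").attrs["arch"])  # sm_75, below sm_80
```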

tests/python/unittest/test_meta_schedule_space_cuda.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -27,7 +27,7 @@
 
 
 def _target():
-    return Target("nvidia/geforce-rtx-3070")
+    return Target("nvidia/geforce-rtx-2080")  # disable async trace using sm75
 
 
 def _design_space(mod):
```
