Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions src/meta_schedule/schedule_rule/multi_level_tiling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,23 @@ void MultiLevelTilingNode::InitializeWithTuneContext(const TuneContext& context)
TVM_PY_LOG(INFO, context->logger) << "'thread_warp_size' is not defined in the target";
}
}
if (Optional<String> opt_sm = context->target.value()->GetAttr<String>("arch")) {
std::string sm = opt_sm.value();
if (support::StartsWith(sm, "sm_")) {
sm = sm.substr(3);
try {
// only sm_80 or higher supports async memcopy
if (std::stoi(sm) >= 80) {
// only stage = 4 & 5 is tested. all integer that is bigger than 2
// is theoretically feasible, but no guarantee for great performance.
this->stages.insert(this->stages.end(), {4, 5});
}
} catch (const std::invalid_argument& e) {
LOG(WARNING) << "ValueError: Unable to parse `target.arch`: " << sm
<< ". Details: " << e.what();
}
}
}
logger = context->logger;
}

Expand Down Expand Up @@ -115,6 +132,8 @@ std::vector<State> MultiLevelTilingNode::ApplySubRules(std::vector<State> states
states = SubRule(std::move(states), [&](State state) { return TileLoopNest(std::move(state)); });
states = SubRule(std::move(states), [&](State state) { return AddWriteReuse(std::move(state)); });
states = SubRule(std::move(states), [&](State state) { return AddReadReuse(std::move(state)); });
states =
SubRule(std::move(states), [&](State state) { return AddAsyncPipeline(std::move(state)); });
return states;
}

Expand Down Expand Up @@ -280,6 +299,43 @@ std::vector<State> MultiLevelTilingNode::AddReadReuse(State state) const {
return results;
}

std::vector<State> MultiLevelTilingNode::AddAsyncPipeline(State state) const {
  // SubRule: annotate the fused reduction loop nest for software pipelining with
  // async memcopy. For archs that do not support async memcopy, `this->stages`
  // is empty and this rule is a no-op; likewise when there is no reduction tile.
  if (r_indices_.empty() || this->stages.empty()) {
    return {state};
  }
  // Currently only the default config used by ScheduleRule::DefaultCUDA is supported.
  // @see src/meta_schedule/schedule_rule/schedule_rule.cc
  // The outermost reduction loop body must be a sequence of exactly 3 for-loops,
  // so that it matches the annotation array sizes in the code below.
  tir::StmtSRef r_loop_sref = state->sch->GetSRef(state->tiles[r_indices_[0]].back());
  const tir::ForNode* r_for_loop = TVM_SREF_TO_FOR(r_loop_sref);
  // Use a checked cast: a non-SeqStmt body is an unsupported config and should
  // fall through gracefully rather than fail an unchecked Downcast.
  const tir::SeqStmtNode* seq_node = r_for_loop->body.as<tir::SeqStmtNode>();
  if (seq_node == nullptr || seq_node->seq.size() != 3) {
    return {state};
  }
  for (const tir::Stmt& stmt : seq_node->seq) {
    if (!stmt.as<tir::ForNode>()) {
      return {state};
    }
  }

  std::vector<State> ret;
  ret.reserve(this->stages.size() + 1);
  // Keep the un-pipelined sketch as a candidate alongside each pipelined variant.
  ret.push_back(state);
  for (int stage : this->stages) {
    State new_state = state->Copy();
    LoopRV r_loop_fused = new_state->sch->Fuse(new_state->tiles[r_indices_[0]]);
    // Pipeline stages for the 3 statements: both reads in stage 0, compute in
    // stage `stage - 2` (presumably leaving `stage` buffers in flight — the
    // lowering pass defines the exact semantics).
    new_state->sch->Annotate(r_loop_fused, tir::attr::software_pipeline_stage,
                             Array<Integer>{0, 0, stage - 2});
    new_state->sch->Annotate(r_loop_fused, tir::attr::software_pipeline_order,
                             Array<Integer>{0, 1, 2});
    // Statements placed in stage 0 (the cache reads) are issued asynchronously.
    new_state->sch->Annotate(r_loop_fused, tir::attr::software_pipeline_async_stages,
                             Array<Integer>{0});
    ret.push_back(std::move(new_state));
  }
  return ret;
}

void MultiLevelTilingNode::AnnotateCooperativeFetching(Schedule* sch,
const tir::BlockRV& block) const {
// Filter out invalid vector lanes according to the data type.
Expand Down
4 changes: 4 additions & 0 deletions src/meta_schedule/schedule_rule/multi_level_tiling.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,8 @@ class MultiLevelTilingNode : public ScheduleRuleNode {
std::vector<State> TileLoopNest(State state) const;
// SubRule 3. add read cache
std::vector<State> AddReadReuse(State state) const;
// SubRule 4. add async pipeline
std::vector<State> AddAsyncPipeline(State state) const;

// Do nothing; Inherited from ScheduleRuleNode
void InitializeWithTuneContext(const TuneContext& context) final;
Expand Down Expand Up @@ -192,6 +194,8 @@ class MultiLevelTilingNode : public ScheduleRuleNode {
int thread_warp_size_;
/*! \brief The maximum number of threads to be used size of a thread warp */
int max_threads_per_block_;
/*! \brief All available async pipeline stages. */
std::vector<int> stages;
/*! \brief The logging function */
PackedFunc logger;
/*! \brief The function to overwrite the default condition for applying MultiLevelTiling. */
Expand Down
6 changes: 3 additions & 3 deletions tests/python/unittest/test_meta_schedule_schedule_rule_mlt.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ def cuda_matmul_0(
actual = generate_design_space(
kind="cuda",
mod=mod,
target=Target("nvidia/geforce-rtx-3080"),
target=Target("nvidia/geforce-rtx-2080"), # disable async trace using sm75
types=ms.schedule_rule.MultiLevelTiling,
)
check_sketches(
Expand Down Expand Up @@ -483,7 +483,7 @@ def cuda_matmul_relu_0(
actual = generate_design_space(
kind="cuda",
mod=mod,
target=Target("nvidia/geforce-rtx-3080"),
target=Target("nvidia/geforce-rtx-2080"), # disable async trace using sm75
types=ms.schedule_rule.MultiLevelTiling,
)
check_sketches(
Expand Down Expand Up @@ -723,7 +723,7 @@ def cache_read_specify_consumer_0(
space = generate_design_space(
kind="cuda",
mod=mod,
target=Target("nvidia/geforce-rtx-3080"),
target=Target("nvidia/geforce-rtx-2080"), # disable async trace using sm75
types=ms.schedule_rule.MultiLevelTiling,
)
check_sketches(
Expand Down
2 changes: 1 addition & 1 deletion tests/python/unittest/test_meta_schedule_space_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@


def _target():
return Target("nvidia/geforce-rtx-3070")
return Target("nvidia/geforce-rtx-2080") # disable async trace using sm75


def _design_space(mod):
Expand Down
Loading