Skip to content

Commit f8c8678

Browse files
vinx13masahi
authored and committed
[MetaSchedule] Handle 'warp_execution' implied extent of threadIdx.x in VerifyGpuCode (apache#11949)
1 parent 9f7e36f commit f8c8678

File tree

3 files changed

+417
-45
lines changed

3 files changed

+417
-45
lines changed

include/tvm/tir/stmt.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1525,6 +1525,12 @@ constexpr const char* meta_schedule_auto_tensorize = "meta_schedule.auto_tensori
15251525
/*! \brief Mark that a block is a preprocessor block for layout rewrite. */
15261526
constexpr const char* meta_schedule_layout_rewrite_preproc = "meta_schedule.layout_rewrite_preproc";
15271527

1528+
/*!
1529+
* \brief Mark that a block is executed by a warp. This implies the extent of threadIdx.x is
1530+
* warp size.
1531+
*/
1532+
constexpr const char* warp_execution = "warp_execution";
1533+
15281534
/*!
15291535
* \brief Check if attr_key is a pragma key extension
15301536
* \param attr_key The attr key to be compared

src/meta_schedule/postproc/verify_gpu_code.cc

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,20 @@ namespace tir {
2525

2626
class ThreadExtentChecker : private StmtVisitor {
2727
public:
28-
static bool Check(const Stmt& stmt) {
28+
static bool Check(const Stmt& stmt, int thread_warp_size) {
2929
try {
30-
ThreadExtentChecker().VisitStmt(stmt);
30+
ICHECK(thread_warp_size > 0);
31+
ThreadExtentChecker checker(thread_warp_size);
32+
checker.VisitStmt(stmt);
3133
return true;
3234
} catch (const dmlc::Error& e) {
3335
return false;
3436
}
3537
}
3638

3739
private:
40+
explicit ThreadExtentChecker(int thread_warp_size) : thread_warp_size_(thread_warp_size) {}
41+
3842
void VisitStmt_(const ForNode* loop) {
3943
runtime::ThreadScope thread_scope = GetThreadScope(loop);
4044
if (IsThreadIdx(thread_scope)) {
@@ -64,6 +68,10 @@ class ThreadExtentChecker : private StmtVisitor {
6468
}
6569

6670
void VisitStmt_(const BlockNode* block) {
71+
int old_thread_idx_x = thread_idx_x;
72+
if (block->annotations.count(attr::warp_execution)) {
73+
thread_idx_x = thread_warp_size_;
74+
}
6775
if (Optional<Integer> low_inclusive =
6876
GetAnn<Integer>(block, attr::meta_schedule_thread_extent_low_inclusive)) {
6977
if (Optional<Integer> high_inclusive =
@@ -77,11 +85,13 @@ class ThreadExtentChecker : private StmtVisitor {
7785
}
7886
}
7987
StmtVisitor::VisitStmt_(block);
88+
thread_idx_x = old_thread_idx_x;
8089
}
8190

8291
int64_t thread_idx_x = 1;
8392
int64_t thread_idx_y = 1;
8493
int64_t thread_idx_z = 1;
94+
int thread_warp_size_ = -1;
8595
};
8696

8797
} // namespace tir
@@ -104,6 +114,7 @@ Integer Extract(const Target& target, const char* name) {
104114
class VerifyGPUCodeNode : public PostprocNode {
105115
public:
106116
Map<String, PrimExpr> target_constraints_{nullptr};
117+
int thread_warp_size_ = -1;
107118

108119
void InitializeWithTuneContext(const TuneContext& context) final {
109120
ICHECK(context->target.defined());
@@ -114,6 +125,7 @@ class VerifyGPUCodeNode : public PostprocNode {
114125
{"max_vthread", Integer(8)},
115126
{"max_vector_bytes", Integer(16)},
116127
};
128+
thread_warp_size_ = Extract(target, "thread_warp_size");
117129
}
118130

119131
bool Verify(const IRModule& mod) const {
@@ -133,7 +145,7 @@ class VerifyGPUCodeNode : public PostprocNode {
133145
const GlobalVar& g_var = kv.first;
134146
const BaseFunc& base_func = kv.second;
135147
if (const auto* prim_func = base_func.as<tir::PrimFuncNode>()) {
136-
if (!tir::ThreadExtentChecker::Check(prim_func->body)) {
148+
if (!tir::ThreadExtentChecker::Check(prim_func->body, thread_warp_size_)) {
137149
return false;
138150
}
139151
IRModule lowered{nullptr};

0 commit comments

Comments
 (0)