@@ -25,16 +25,20 @@ namespace tir {
2525
2626class ThreadExtentChecker : private StmtVisitor {
2727 public:
28- static bool Check (const Stmt& stmt) {
28+ static bool Check (const Stmt& stmt, int thread_warp_size ) {
2929 try {
30- ThreadExtentChecker ().VisitStmt (stmt);
30+ ICHECK (thread_warp_size > 0 );
31+ ThreadExtentChecker checker (thread_warp_size);
32+ checker.VisitStmt (stmt);
3133 return true ;
3234 } catch (const dmlc::Error& e) {
3335 return false ;
3436 }
3537 }
3638
3739 private:
40+ explicit ThreadExtentChecker (int thread_warp_size) : thread_warp_size_(thread_warp_size) {}
41+
3842 void VisitStmt_ (const ForNode* loop) {
3943 runtime::ThreadScope thread_scope = GetThreadScope (loop);
4044 if (IsThreadIdx (thread_scope)) {
@@ -64,6 +68,10 @@ class ThreadExtentChecker : private StmtVisitor {
6468 }
6569
6670 void VisitStmt_ (const BlockNode* block) {
71+ int old_thread_idx_x = thread_idx_x;
72+ if (block->annotations .count (attr::warp_execution)) {
73+ thread_idx_x = thread_warp_size_;
74+ }
6775 if (Optional<Integer> low_inclusive =
6876 GetAnn<Integer>(block, attr::meta_schedule_thread_extent_low_inclusive)) {
6977 if (Optional<Integer> high_inclusive =
@@ -77,11 +85,13 @@ class ThreadExtentChecker : private StmtVisitor {
7785 }
7886 }
7987 StmtVisitor::VisitStmt_ (block);
88+ thread_idx_x = old_thread_idx_x;
8089 }
8190
8291 int64_t thread_idx_x = 1 ;
8392 int64_t thread_idx_y = 1 ;
8493 int64_t thread_idx_z = 1 ;
94+ int thread_warp_size_ = -1 ;
8595};
8696
8797} // namespace tir
@@ -104,6 +114,7 @@ Integer Extract(const Target& target, const char* name) {
104114class VerifyGPUCodeNode : public PostprocNode {
105115 public:
106116 Map<String, PrimExpr> target_constraints_{nullptr };
117+ int thread_warp_size_ = -1 ;
107118
108119 void InitializeWithTuneContext (const TuneContext& context) final {
109120 ICHECK (context->target .defined ());
@@ -114,6 +125,7 @@ class VerifyGPUCodeNode : public PostprocNode {
114125 {" max_vthread" , Integer (8 )},
115126 {" max_vector_bytes" , Integer (16 )},
116127 };
128+ thread_warp_size_ = Extract (target, " thread_warp_size" );
117129 }
118130
119131 bool Verify (const IRModule& mod) const {
@@ -133,7 +145,7 @@ class VerifyGPUCodeNode : public PostprocNode {
133145 const GlobalVar& g_var = kv.first ;
134146 const BaseFunc& base_func = kv.second ;
135147 if (const auto * prim_func = base_func.as <tir::PrimFuncNode>()) {
136- if (!tir::ThreadExtentChecker::Check (prim_func->body )) {
148+ if (!tir::ThreadExtentChecker::Check (prim_func->body , thread_warp_size_ )) {
137149 return false ;
138150 }
139151 IRModule lowered{nullptr };
0 commit comments