2626#include < tvm/tir/stmt.h>
2727#include < tvm/tir/stmt_functor.h>
2828
29+ #include < exception>
30+ #include < optional>
31+ #include < tuple>
32+ #include < variant>
33+
2934#include " ../ir/functor_common.h"
35+ #include " ../ir/tir_visitor_with_path.h"
3036#include " tvm/ir/module.h"
3137
3238namespace tvm {
3339namespace tir {
3440
41+ namespace {
42+
43+ template <typename DerivedVerifier>
44+ class Verifier : protected TIRVisitorWithPath {
45+ public:
46+ template <typename TirNodeRef>
47+ static bool Verify (const TirNodeRef& node, bool assert_on_error) {
48+ DerivedVerifier verifier (assert_on_error);
49+ verifier (node);
50+ return !verifier.has_error_ ;
51+ }
52+
53+ protected:
54+ explicit Verifier (bool assert_on_error) : assert_on_error_(assert_on_error) {}
55+
56+ /* \brief Helper class to handle the bool-or-assert handles
57+ *
58+ * Each verifier can either return a boolean, or assert on failure.
59+ * To avoid needing to duplicate this logic at every step, the
60+ * Verify() method can be used. Similar to `LOG(FATAL)` or
61+ * `LOG(DEBUG)`, it returns an object that can accept streamed
62+ * context information.
63+ *
64+ * If the error should be raised, then the context is collected
65+ * identically to `LOG(FATAL)`. If a boolean is returned, or if the
66+ * condition passes, then the streamed context is discarded.
67+ *
68+ * Usage:
69+ *
70+ * Verify(value == expected_value)
71+ * << "ValueError: " << value
72+ * << " was not the expected value of " << expected_value;
73+ */
74+ class VerifyStream {
75+ public:
76+ explicit VerifyStream (bool log_fatal) {
77+ if (log_fatal) {
78+ log_.emplace ();
79+ }
80+ }
81+
82+ VerifyStream (const VerifyStream&) = delete ;
83+ VerifyStream& operator =(const VerifyStream&) = delete ;
84+ VerifyStream (VerifyStream&& other) { std::swap (log_, other.log_ ); }
85+ VerifyStream& operator =(VerifyStream&& other) {
86+ std::swap (log_, other.log_ );
87+ return *this ;
88+ }
89+
90+ template <typename T>
91+ VerifyStream& operator <<(T&& t) {
92+ if (log_.has_value ()) {
93+ log_.value () << std::forward<T>(t);
94+ }
95+ return *this ;
96+ }
97+
98+ ~VerifyStream () noexcept (false ) {
99+ if (log_.has_value ()) {
100+ LOG (FATAL) << log_->str ();
101+ }
102+ }
103+
104+ std::optional<std::ostringstream> log_{std::nullopt };
105+ };
106+
107+ // TODO(Lunderberg): Add the filename/linenum with
108+ // std::source_location when C++20 is available.
109+ VerifyStream Verify (bool condition) {
110+ has_error_ = has_error_ || !condition;
111+ return VerifyStream (!condition && assert_on_error_);
112+ }
113+
114+ bool assert_on_error_;
115+ bool has_error_{false };
116+ };
117+
118+ } // namespace
119+
35120/* ! \brief Verify all Expr inside the block does not contain:
36121 * 1. loop vars outside the current block.
37122 * 2. block vars of parent blocks.
@@ -135,10 +220,135 @@ class BlockVarAccessVerifier : public StmtExprVisitor {
135220 bool has_error_{false };
136221};
137222
223+ class UndefinedVarVerifier : public Verifier <UndefinedVarVerifier> {
224+ public:
225+ // Until templated-this arrives in C++23, the CRTP can't inject a
226+ // constructor into the child class. Therefore, must explicitly add
227+ // the constructor.
228+ using Verifier::Verifier;
229+
230+ private:
231+ void Visit (const PrimFunc& prim_func, ObjectPath path) override {
232+ Verifier::Visit (prim_func, path);
233+ redefine_allowed_within_function_.clear ();
234+ }
235+
236+ void EnterDef (const IterVar& iter_var, ObjectPath path) override {
237+ Verifier::EnterDef (iter_var, path);
238+ if (iter_var->iter_type == IterVarType::kThreadIndex ) {
239+ redefine_allowed_within_function_.insert (iter_var->var );
240+ }
241+ }
242+
243+ void EnterDef (const Var& var, ObjectPath path) override {
244+ bool redefine_is_allowed = redefine_allowed_within_function_.count (var);
245+ {
246+ auto it = currently_defined_.find (var);
247+ Verify (it == currently_defined_.end () || redefine_is_allowed)
248+ << " ValueError: "
249+ << " TIR is ill-formed, "
250+ << " due to multiple nested definitions of variable " << var
251+ << " . It was first defined at " << it->second << " , and was re-defined at " << path;
252+ }
253+
254+ {
255+ auto it = previously_defined_.find (var);
256+ Verify (it == previously_defined_.end () || redefine_is_allowed)
257+ << " ValueError: "
258+ << " TIR is ill-formed, "
259+ << " due to multiple definitions of variable " << var << " . It was first defined at "
260+ << it->second << " , and was later re-defined at " << path;
261+ }
262+
263+ currently_defined_.insert ({var, path});
264+ }
265+
266+ void ExitDef (const Var& var, ObjectPath path) override {
267+ auto active_def = currently_defined_.find (var);
268+
269+ currently_defined_.erase (active_def);
270+ previously_defined_.insert ({var, path});
271+ }
272+
273+ void VisitExpr_ (const VarNode* op, ObjectPath path) override {
274+ auto var = GetRef<Var>(op);
275+
276+ auto active_def = currently_defined_.find (var);
277+ auto verify = Verify (active_def != currently_defined_.end ());
278+ verify << " ValueError: "
279+ << " Invalid use of undefined variable " << var << " at " << path << " ." ;
280+
281+ // Check if there was a previous definition, and append the
282+ // location to the error message if there was. This is to aid in
283+ // debugging, by distinguishing between a variable that is
284+ // currently out-of-scope, and a variable that never had a
285+ // definition in the first place.
286+ if (auto prev_def = previously_defined_.find (var); prev_def != previously_defined_.end ()) {
287+ verify << " . While this variable was previously defined at " << prev_def->second
288+ << " , this definition is no longer in-scope." ;
289+ }
290+ }
291+
292+ // Variables that are defined in the currently-visited scope.
293+ std::unordered_map<Var, ObjectPath, ObjectPtrHash, ObjectPtrEqual> currently_defined_;
294+
295+ // Variables that were previously defined, and are now out of scope.
296+ std::unordered_map<Var, ObjectPath, ObjectPtrHash, ObjectPtrEqual> previously_defined_;
297+
298+ // Special variables that are allowed to be re-defined, so long as
299+ // that re-definition occurs within the same PrimFunc. For example
300+ std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> redefine_allowed_within_function_;
301+ };
302+
303+ /* \brief Verify unique tir::Var for each environment thread
304+ *
305+ * Environment threads, such as CUDA's `threadIdx.x`, are defined in
306+ * TIR using an `AttrStmt` with the key `attr::thread_extent`. A
307+ * `PrimFunc` may contain multiple such attributes for the same
308+ * environment thread. However, all such attributes must use the same
309+ * `tir::Var` for a given thread.
310+ */
311+ class SingleEnvThreadVerifier : public Verifier <SingleEnvThreadVerifier> {
312+ public:
313+ using Verifier::Verifier;
314+
315+ private:
316+ void Visit (const PrimFunc& prim_func, ObjectPath path) override {
317+ Verifier::Visit (prim_func, path);
318+ env_thread_vars_.clear ();
319+ }
320+
321+ void EnterDef (const IterVar& iter_var, ObjectPath path) override {
322+ if (iter_var->iter_type == IterVarType::kThreadIndex ) {
323+ if (auto it = env_thread_vars_.find (iter_var->thread_tag ); it != env_thread_vars_.end ()) {
324+ const auto & [prev_var, prev_path] = it->second ;
325+ Verify (prev_var.same_as (iter_var->var ))
326+ << " ValueError: "
327+ << " PrimFunc uses multiple distinct TIR variables "
328+ << " for the environment thread \" " << iter_var->thread_tag << " \" . "
329+ << " While multiple tir::AttrStmt may define the same environment thread, "
330+ << " all definitions within a single PrimFunc must share the same tir::Var. "
331+ << " Binding of environment thread \" " << iter_var->thread_tag
332+ << " \" to the TIR variable " << iter_var->var << " at " << path
333+ << " conflicts with the previous binding to the TIR variable " << prev_var << " at "
334+ << path;
335+ } else {
336+ env_thread_vars_.insert ({iter_var->thread_tag , {iter_var->var , path}});
337+ }
338+ }
339+ }
340+
341+ std::unordered_map<String, std::tuple<Var, ObjectPath>> env_thread_vars_;
342+ };
343+
138344bool VerifyWellFormed (const PrimFunc& func, bool assert_mode) {
139345 if (!BlockVarAccessVerifier::Verify (func, assert_mode)) {
140346 return false ;
141347 }
348+
349+ if (!UndefinedVarVerifier::Verify (func, assert_mode)) return false ;
350+ if (!SingleEnvThreadVerifier::Verify (func, assert_mode)) return false ;
351+
142352 // TODO(Siyuan): add more checks here.
143353 return true ;
144354}
@@ -152,6 +362,10 @@ bool VerifyWellFormed(const IRModule& mod, bool assert_mode) {
152362 }
153363 }
154364 }
365+
366+ if (!UndefinedVarVerifier::Verify (mod, assert_mode)) return false ;
367+ if (!SingleEnvThreadVerifier::Verify (mod, assert_mode)) return false ;
368+
155369 return true ;
156370}
157371
0 commit comments