Skip to content

Commit eb15d04

Browse files
authored
[TIR] In SplitHostDevice, check for variables in thread extents (#16250)
* [TIR] In SplitHostDevice, check for variables in thread extents Otherwise, they would be undefined after being de-duplicated by `ConvertSSA`. * Revert #16236 The buf reported in #16237 can be resolved by tracking variable usage in a thread extent. * lint fixes * Update TIR well-formed checker for env thread SSA requirements Environment threads must reuse the same `tir::Var` across all `AttrStmt` instances in a `PrimFunc`, but must not reuse across separate `PrimFunc`s in an `IRModule`. * Update ConvertSSA to handle environment threads' SSA requirements * lint fix * Updated docstrings for VerifyWellFormed * Rely on script.Complete for read/writes Avoids issue in cortexm unit tests resulting from read/write annotations being present in the root block, followed by application of BindParams. * Typo fix * Added structural equal comparison in unit test
1 parent 8eec0bf commit eb15d04

File tree

10 files changed

+1383
-33
lines changed

10 files changed

+1383
-33
lines changed

include/tvm/tir/analysis.h

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -307,13 +307,40 @@ TVM_DLL Map<Buffer, Optional<Stmt>> DetectBufferAccessLCA(const PrimFunc& func);
307307

308308
/*!
309309
* \brief Verify if the given TIR is well-formed. The verification includes:
310-
* - Check if expressions not contain vars that is defined outside the block.
310+
*
311+
* - All variables are defined prior to their point of use.
312+
*
313+
* - No variables are used outside of the scope of their definition.
314+
*
315+
* - Each variable has a single point of definition.
316+
*
317+
* - Expressions within a tir::Block may not reference variables
318+
* defined outside the block. For example, for a block with iter
319+
* vars `vi, vj = T.axis.remap('SS', [i,j])`, the statement
320+
* `B[i,j] = A[i,j]` would be ill-formed, because it uses the loop
321+
* variables `i` and `j` instead of the block variables `vi` and
322+
* `vj`.
323+
*
311324
* \param func The PrimFunc to be verified.
312325
* \param assert_mode The indicator if it raises an error when the function is not well-formed.
313326
* \return Whether it is a well-formed TIR function.
314327
*/
315328
TVM_DLL bool VerifyWellFormed(const PrimFunc& func, bool assert_mode = true);
316329

330+
/*!
331+
* \brief Verify if the TIR in the given IRMOdule is well-formed.
332+
*
333+
* In addition to the checks performed for each PrimFunc (see above),
334+
* the following checks are performed:
335+
*
336+
* - The same TIR variable may not be defined in more than one function
337+
*
338+
* \param mod The IRModule to be verified.
339+
* \param assert_mode The indicator if it raises an error when the function is not well-formed.
340+
* \return Whether it is a well-formed TIR module.
341+
*/
342+
TVM_DLL bool VerifyWellFormed(const IRModule& mod, bool assert_mode = true);
343+
317344
/*!
318345
* \brief Find the entry function of the given IRModule, i.e, functions marked by
319346
* `tir::attr::kIsEntryFunc`, whose name is `main` or being the only PrimeFunc.

src/te/operation/create_primfunc.cc

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -424,15 +424,12 @@ Stmt GenerateStmtFromExternOp(const te::ExternOp& extern_op, CreateFuncInfo* inf
424424
}
425425
}
426426

427-
// Step 3. Collect Access Region
428-
Array<BufferRegion> reads, writes;
429-
for (const te::Tensor& tensor : extern_op->inputs) {
430-
// We have ICHECK before so it is not needed here.
431-
reads.push_back(BufferRegion::FullRegion(info->tensor2buffers[tensor]));
432-
}
433-
for (const Buffer& buffer : extern_op->output_placeholders) {
434-
writes.push_back(BufferRegion::FullRegion(buffer));
435-
}
427+
// The access region does not need to be collected here, as it will
428+
// be generated with the later application of "script.Complete" in
429+
// GenerateAndCompletePrimFunc. Waiting until later also handles
430+
// the case where there is only a single BlockNode, which then
431+
// becomes the root Block of the function, and should not have
432+
// reads/writes filled in.
436433

437434
BufferSubstituter substituter(var_map, input_buffer_map);
438435
Stmt body = substituter(extern_op->body);
@@ -442,8 +439,8 @@ Stmt GenerateStmtFromExternOp(const te::ExternOp& extern_op, CreateFuncInfo* inf
442439
/*predicate=*/Bool(true),
443440
/*block=*/
444441
Block(/*iter_vars=*/{},
445-
/*reads=*/std::move(reads),
446-
/*writes=*/std::move(writes),
442+
/*reads=*/{},
443+
/*writes=*/{},
447444
/*name_hint=*/info->FreshName(extern_op->name),
448445
/*body=*/std::move(body),
449446
/*init=*/NullOpt,

src/tir/analysis/verify_well_formed.cc

Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,97 @@
2626
#include <tvm/tir/stmt.h>
2727
#include <tvm/tir/stmt_functor.h>
2828

29+
#include <exception>
30+
#include <optional>
31+
#include <tuple>
32+
#include <variant>
33+
2934
#include "../ir/functor_common.h"
35+
#include "../ir/tir_visitor_with_path.h"
3036
#include "tvm/ir/module.h"
3137

3238
namespace tvm {
3339
namespace tir {
3440

41+
namespace {
42+
43+
template <typename DerivedVerifier>
44+
class Verifier : protected TIRVisitorWithPath {
45+
public:
46+
template <typename TirNodeRef>
47+
static bool Verify(const TirNodeRef& node, bool assert_on_error) {
48+
DerivedVerifier verifier(assert_on_error);
49+
verifier(node);
50+
return !verifier.has_error_;
51+
}
52+
53+
protected:
54+
explicit Verifier(bool assert_on_error) : assert_on_error_(assert_on_error) {}
55+
56+
/* \brief Helper class to handle the bool-or-assert handles
57+
*
58+
* Each verifier can either return a boolean, or assert on failure.
59+
* To avoid needing to duplicate this logic at every step, the
60+
* Verify() method can be used. Similar to `LOG(FATAL)` or
61+
* `LOG(DEBUG)`, it returns an object that can accept streamed
62+
* context information.
63+
*
64+
* If the error should be raised, then the context is collected
65+
* identically to `LOG(FATAL)`. If a boolean is returned, or if the
66+
* condition passes, then the streamed context is discarded.
67+
*
68+
* Usage:
69+
*
70+
* Verify(value == expected_value)
71+
* << "ValueError: " << value
72+
* << " was not the expected value of " << expected_value;
73+
*/
74+
class VerifyStream {
75+
public:
76+
explicit VerifyStream(bool log_fatal) {
77+
if (log_fatal) {
78+
log_.emplace();
79+
}
80+
}
81+
82+
VerifyStream(const VerifyStream&) = delete;
83+
VerifyStream& operator=(const VerifyStream&) = delete;
84+
VerifyStream(VerifyStream&& other) { std::swap(log_, other.log_); }
85+
VerifyStream& operator=(VerifyStream&& other) {
86+
std::swap(log_, other.log_);
87+
return *this;
88+
}
89+
90+
template <typename T>
91+
VerifyStream& operator<<(T&& t) {
92+
if (log_.has_value()) {
93+
log_.value() << std::forward<T>(t);
94+
}
95+
return *this;
96+
}
97+
98+
~VerifyStream() noexcept(false) {
99+
if (log_.has_value()) {
100+
LOG(FATAL) << log_->str();
101+
}
102+
}
103+
104+
std::optional<std::ostringstream> log_{std::nullopt};
105+
};
106+
107+
// TODO(Lunderberg): Add the filename/linenum with
108+
// std::source_location when C++20 is available.
109+
VerifyStream Verify(bool condition) {
110+
has_error_ = has_error_ || !condition;
111+
return VerifyStream(!condition && assert_on_error_);
112+
}
113+
114+
bool assert_on_error_;
115+
bool has_error_{false};
116+
};
117+
118+
} // namespace
119+
35120
/*! \brief Verify all Expr inside the block does not contain:
36121
* 1. loop vars outside the current block.
37122
* 2. block vars of parent blocks.
@@ -135,10 +220,135 @@ class BlockVarAccessVerifier : public StmtExprVisitor {
135220
bool has_error_{false};
136221
};
137222

223+
class UndefinedVarVerifier : public Verifier<UndefinedVarVerifier> {
224+
public:
225+
// Until templated-this arrives in C++23, the CRTP can't inject a
226+
// constructor into the child class. Therefore, must explicitly add
227+
// the constructor.
228+
using Verifier::Verifier;
229+
230+
private:
231+
void Visit(const PrimFunc& prim_func, ObjectPath path) override {
232+
Verifier::Visit(prim_func, path);
233+
redefine_allowed_within_function_.clear();
234+
}
235+
236+
void EnterDef(const IterVar& iter_var, ObjectPath path) override {
237+
Verifier::EnterDef(iter_var, path);
238+
if (iter_var->iter_type == IterVarType::kThreadIndex) {
239+
redefine_allowed_within_function_.insert(iter_var->var);
240+
}
241+
}
242+
243+
void EnterDef(const Var& var, ObjectPath path) override {
244+
bool redefine_is_allowed = redefine_allowed_within_function_.count(var);
245+
{
246+
auto it = currently_defined_.find(var);
247+
Verify(it == currently_defined_.end() || redefine_is_allowed)
248+
<< "ValueError: "
249+
<< "TIR is ill-formed, "
250+
<< "due to multiple nested definitions of variable " << var
251+
<< ". It was first defined at " << it->second << ", and was re-defined at " << path;
252+
}
253+
254+
{
255+
auto it = previously_defined_.find(var);
256+
Verify(it == previously_defined_.end() || redefine_is_allowed)
257+
<< "ValueError: "
258+
<< "TIR is ill-formed, "
259+
<< "due to multiple definitions of variable " << var << ". It was first defined at "
260+
<< it->second << ", and was later re-defined at " << path;
261+
}
262+
263+
currently_defined_.insert({var, path});
264+
}
265+
266+
void ExitDef(const Var& var, ObjectPath path) override {
267+
auto active_def = currently_defined_.find(var);
268+
269+
currently_defined_.erase(active_def);
270+
previously_defined_.insert({var, path});
271+
}
272+
273+
void VisitExpr_(const VarNode* op, ObjectPath path) override {
274+
auto var = GetRef<Var>(op);
275+
276+
auto active_def = currently_defined_.find(var);
277+
auto verify = Verify(active_def != currently_defined_.end());
278+
verify << "ValueError: "
279+
<< "Invalid use of undefined variable " << var << " at " << path << ".";
280+
281+
// Check if there was a previous definition, and append the
282+
// location to the error message if there was. This is to aid in
283+
// debugging, by distinguishing between a variable that is
284+
// currently out-of-scope, and a variable that never had a
285+
// definition in the first place.
286+
if (auto prev_def = previously_defined_.find(var); prev_def != previously_defined_.end()) {
287+
verify << ". While this variable was previously defined at " << prev_def->second
288+
<< ", this definition is no longer in-scope.";
289+
}
290+
}
291+
292+
// Variables that are defined in the currently-visited scope.
293+
std::unordered_map<Var, ObjectPath, ObjectPtrHash, ObjectPtrEqual> currently_defined_;
294+
295+
// Variables that were previously defined, and are now out of scope.
296+
std::unordered_map<Var, ObjectPath, ObjectPtrHash, ObjectPtrEqual> previously_defined_;
297+
298+
// Special variables that are allowed to be re-defined, so long as
299+
// that re-definition occurs within the same PrimFunc. For example
300+
std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> redefine_allowed_within_function_;
301+
};
302+
303+
/* \brief Verify unique tir::Var for each environment thread
304+
*
305+
* Environment threads, such as CUDA's `threadIdx.x`, are defined in
306+
* TIR using an `AttrStmt` with the key `attr::thread_extent`. A
307+
* `PrimFunc` may contain multiple such attributes for the same
308+
* environment thread. However, all such attributes must use the same
309+
* `tir::Var` for a given thread.
310+
*/
311+
class SingleEnvThreadVerifier : public Verifier<SingleEnvThreadVerifier> {
312+
public:
313+
using Verifier::Verifier;
314+
315+
private:
316+
void Visit(const PrimFunc& prim_func, ObjectPath path) override {
317+
Verifier::Visit(prim_func, path);
318+
env_thread_vars_.clear();
319+
}
320+
321+
void EnterDef(const IterVar& iter_var, ObjectPath path) override {
322+
if (iter_var->iter_type == IterVarType::kThreadIndex) {
323+
if (auto it = env_thread_vars_.find(iter_var->thread_tag); it != env_thread_vars_.end()) {
324+
const auto& [prev_var, prev_path] = it->second;
325+
Verify(prev_var.same_as(iter_var->var))
326+
<< "ValueError: "
327+
<< "PrimFunc uses multiple distinct TIR variables "
328+
<< " for the environment thread \"" << iter_var->thread_tag << "\". "
329+
<< "While multiple tir::AttrStmt may define the same environment thread, "
330+
<< "all definitions within a single PrimFunc must share the same tir::Var. "
331+
<< "Binding of environment thread \"" << iter_var->thread_tag
332+
<< "\" to the TIR variable " << iter_var->var << " at " << path
333+
<< " conflicts with the previous binding to the TIR variable " << prev_var << " at "
334+
<< path;
335+
} else {
336+
env_thread_vars_.insert({iter_var->thread_tag, {iter_var->var, path}});
337+
}
338+
}
339+
}
340+
341+
std::unordered_map<String, std::tuple<Var, ObjectPath>> env_thread_vars_;
342+
};
343+
138344
bool VerifyWellFormed(const PrimFunc& func, bool assert_mode) {
139345
if (!BlockVarAccessVerifier::Verify(func, assert_mode)) {
140346
return false;
141347
}
348+
349+
if (!UndefinedVarVerifier::Verify(func, assert_mode)) return false;
350+
if (!SingleEnvThreadVerifier::Verify(func, assert_mode)) return false;
351+
142352
// TODO(Siyuan): add more checks here.
143353
return true;
144354
}
@@ -152,6 +362,10 @@ bool VerifyWellFormed(const IRModule& mod, bool assert_mode) {
152362
}
153363
}
154364
}
365+
366+
if (!UndefinedVarVerifier::Verify(mod, assert_mode)) return false;
367+
if (!SingleEnvThreadVerifier::Verify(mod, assert_mode)) return false;
368+
155369
return true;
156370
}
157371

0 commit comments

Comments
 (0)