Skip to content

Commit e8cd33b

Browse files
authored
[TIR] Update SplitHostDevice to post-process with ConvertSSA (#14496)
* [TIR][Utils] Implemented ConvertSSA as IRModule transform When passes create new PrimFuncs, such as when `tir.SplitHostDevice` separates out a `tir::Stmt` into an independent function, the parameters of these new functions may alias existing variable definitions. While this is well-defined, because variable definitions are not shared across function boundaries, it can give false discrepancies from `tvm.ir.assert_structural_equal`. This commit implements `tvm::tir::transform::ConvertSSA`, which ensures unique variable declaration locations across an entire module. * [TIR] Update SplitHostDevice to post-process with ConvertSSA Avoid duplicate variable definitions between the host and device PrimFunc.
1 parent 4d59c95 commit e8cd33b

File tree

4 files changed

+205
-21
lines changed

4 files changed

+205
-21
lines changed

include/tvm/tir/transform.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,19 @@ TVM_DLL Pass RewriteUnsafeSelect();
176176
*/
177177
TVM_DLL Pass Simplify();
178178

179+
/*!
180+
* \brief Convert an IRModule to be SSA form.
181+
*
182+
* This pass handles cases where the same tir::Var appears in
183+
* multiple functions within the same module. For example, after
184+
* extracting a fragment from one function into another, where the
185+
* same `tir::Var` may be defined both within the body of the
186+
* original function, and as a parameter within the hoisted function.
187+
*
188+
* \return The pass.
189+
*/
190+
TVM_DLL Pass ConvertSSA();
191+
179192
/*!
180193
* \brief Instruments bound checkers.
181194
*

src/tir/transforms/ir_utils.cc

Lines changed: 166 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include <tvm/arith/analyzer.h>
2727
#include <tvm/arith/int_solver.h>
2828
#include <tvm/tir/stmt_functor.h>
29+
#include <tvm/tir/transform.h>
2930

3031
#include <unordered_map>
3132
#include <unordered_set>
@@ -90,13 +91,107 @@ Stmt MergeNest(const std::vector<std::vector<Stmt>>& nest, Stmt body) {
9091

9192
class IRConvertSSA final : public StmtExprMutator {
9293
public:
93-
PrimExpr VisitExpr_(const VarNode* op) final {
94-
if (scope_.count(op) && !scope_[op].empty()) {
95-
return scope_[op].back();
96-
} else {
97-
return GetRef<PrimExpr>(op);
94+
/*!
 * \brief Rewrite one PrimFunc so that its variable definitions do not
 *  collide with definitions this converter has already recorded.
 *
 * The converter keeps its `defined_` set between calls, so visiting
 * every function of a module with the same instance makes parameters
 * and implicitly-defined buffer variables unique across the module.
 *
 * \param func The function to convert.
 *
 * \return The original `func` when nothing had to be renamed,
 *  otherwise a new PrimFunc with updated params, buffer map, attrs
 *  and body.  The return type and span are carried over unchanged.
 */
PrimFunc VisitPrimFunc(PrimFunc func) {
  // Redefinitions created for this function.  They must stay alive
  // until the body has been visited, and are released at the end.
  std::vector<ScopedRedefine> redefines;

  // Remap parameters, if they were used in another function
  auto params = func->params.Map([&](const tir::Var& var) -> tir::Var {
    if (defined_.count(var.get())) {
      const ScopedRedefine& redefine = redefines.emplace_back(this, var);
      return redefine.new_var;
    } else {
      defined_.insert(var.get());
      return var;
    }
  });

  // Remap implicitly defined buffer parameters
  {
    std::unordered_set<const VarNode*> defined_params;
    for (const auto& var : func->params) {
      defined_params.insert(var.get());
    }
    for (const auto& [var, buffer] : func->buffer_map) {
      static_cast<void>(var);  // gcc 7.x bug, https://gcc.gnu.org/bugzilla/show_bug.cgi?id=81767
      auto check_expr = [&](const PrimExpr& expr) {
        // Only a bare variable can act as an implicit definition here.
        auto* var_ptr = expr.as<VarNode>();
        if (!var_ptr) return;
        // Explicit parameters were already handled above.
        if (defined_params.count(var_ptr)) return;

        if (defined_.count(var_ptr)) {
          // Named differently from the loop's `var` binding to avoid shadowing it.
          auto implicit_var = GetRef<Var>(var_ptr);
          redefines.emplace_back(this, implicit_var);
        } else {
          defined_.insert(var_ptr);
        }
      };
      for (const auto& dim : buffer->shape) {
        check_expr(dim);
      }
      for (const auto& stride : buffer->strides) {
        check_expr(stride);
      }
      check_expr(buffer->elem_offset);
    }
  }

  // Update the buffer map, based on the redefined parameters
  auto buffer_map = [&]() {
    Map<Var, Buffer> buffer_map;
    bool made_change = false;
    for (const auto& [var, buffer] : func->buffer_map) {
      auto new_var = GetRemappedVar(var);
      auto new_buf = GetRemappedBuffer(buffer);

      made_change = made_change || !var.same_as(new_var) || !buffer.same_as(new_buf);
      buffer_map.Set(new_var, new_buf);
    }
    // Hand back the original object when nothing changed, so that the
    // `same_as` comparison below can detect the no-op case.
    if (made_change) {
      return buffer_map;
    } else {
      return func->buffer_map;
    }
  }();

  // Rewrite any expression or statement stored as an attribute value.
  auto attrs = [&]() -> DictAttrs {
    // A function without attributes has nothing to rewrite, and
    // `func->attrs->dict` must not be dereferenced in that case.
    if (!func->attrs.defined()) {
      return func->attrs;
    }

    Map<String, ObjectRef> dict;
    bool made_change = false;

    for (const auto& [key, old_value] : func->attrs->dict) {
      auto value = old_value;
      if (auto* expr = value.as<PrimExprNode>()) {
        value = VisitExpr(GetRef<PrimExpr>(expr));
      } else if (auto* stmt = value.as<StmtNode>()) {
        value = VisitStmt(GetRef<Stmt>(stmt));
      }

      made_change = made_change || !value.same_as(old_value);
      dict.Set(key, value);
    }

    if (made_change) {
      return DictAttrs(dict);
    } else {
      return func->attrs;
    }
  }();

  auto body = VisitStmt(func->body);

  // If anything changed, update the returned function
  if (!params.same_as(func->params) || !buffer_map.same_as(func->buffer_map) ||
      !attrs.same_as(func->attrs) || !body.same_as(func->body)) {
    // Forward the span as well, so the rewritten function keeps its source location.
    func = PrimFunc(params, body, func->ret_type, buffer_map, attrs, func->span);
  }

  // Pop the redefines in reverse order of creation
  while (redefines.size()) {
    redefines.pop_back();
  }
  return func;
}
193+
194+
// A variable use resolves to whatever GetRemappedVar reports for it.
PrimExpr VisitExpr_(const VarNode* op) final {
  Var original = GetRef<Var>(op);
  return GetRemappedVar(original);
}
100195
PrimExpr VisitExpr_(const LetNode* op) final {
101196
const Var& v = op->var;
102197
if (defined_.count(v.get())) {
@@ -142,18 +237,27 @@ class IRConvertSSA final : public StmtExprMutator {
142237
return node;
143238
}
144239

240+
/*!
 * \brief Look up the variable that currently stands in for `var`.
 *
 * \param var The variable as it appears in the input.
 *
 * \return The last entry recorded for `var` in `scope_`, or `var`
 *  itself when `scope_` has no entry, or only an empty one.
 */
Var GetRemappedVar(Var var) {
  auto entry = scope_.find(var.get());
  if (entry == scope_.end() || entry->second.empty()) {
    return var;
  }
  return entry->second.back();
}
247+
145248
Buffer GetRemappedBuffer(Buffer buf) {
146249
// Determine the buffer var that should be in the updated buffer,
147250
// given the current scope. If no redefines are present, then the
148251
// buffer var is unchanged.
149-
Var new_buffer_var = buf->data;
150-
auto var_it = scope_.find(buf->data.get());
151-
if (var_it != scope_.end() && !var_it->second.empty()) {
152-
new_buffer_var = var_it->second.back();
153-
}
252+
Var new_buffer_var = GetRemappedVar(buf->data);
253+
PrimExpr elem_offset = VisitExpr(buf->elem_offset);
254+
auto visit_expr = [this](const PrimExpr& expr) { return VisitExpr(expr); };
255+
Array<PrimExpr> shape = buf->shape.Map(visit_expr);
256+
Array<PrimExpr> strides = buf->strides.Map(visit_expr);
154257

155258
// If no mapping is required, return the original buffer.
156-
if (new_buffer_var.same_as(buf->data)) {
259+
if (new_buffer_var.same_as(buf->data) && elem_offset.same_as(buf->elem_offset) &&
260+
shape.same_as(buf->shape) && strides.same_as(buf->strides)) {
157261
return buf;
158262
}
159263

@@ -169,9 +273,9 @@ class IRConvertSSA final : public StmtExprMutator {
169273
// new buffer, pushing it onto the scoped stack of existing
170274
// buffers. This will be popped when the new_buffer_var
171275
// redefinition is popped.
172-
Buffer new_buf(new_buffer_var, buf->dtype, buf->shape, buf->strides, buf->elem_offset,
173-
buf->name, buf->data_alignment, buf->offset_factor, buf->buffer_type,
174-
buf->axis_separators, buf->span);
276+
Buffer new_buf(new_buffer_var, buf->dtype, shape, strides, elem_offset, buf->name,
277+
buf->data_alignment, buf->offset_factor, buf->buffer_type, buf->axis_separators,
278+
buf->span);
175279
buffers.push_back(new_buf);
176280
return new_buf;
177281
}
@@ -239,16 +343,33 @@ class IRConvertSSA final : public StmtExprMutator {
239343
}
240344

241345
// Undo this redefinition.  An instance whose `parent` is null has
// nothing to undo.
~ScopedRedefine() {
  if (!parent) {
    return;
  }

  // Drop the most recent entry recorded for the old variable.
  parent->scope_[old_var.get()].pop_back();

  // Drop any remapped buffer whose top-of-stack entry is backed by
  // the variable introduced by this redefinition.
  for (auto& entry : parent->buf_remap_) {
    std::vector<Buffer>& stack = entry.second;
    if (stack.empty()) {
      continue;
    }
    if (stack.back()->data.get() == new_var.get()) {
      stack.pop_back();
    }
  }
}
250356

251-
IRConvertSSA* parent;
357+
ScopedRedefine& operator=(const ScopedRedefine&) = delete;
358+
ScopedRedefine(const ScopedRedefine&) = delete;
359+
360+
ScopedRedefine& operator=(ScopedRedefine&& other) {
361+
swap(other);
362+
return *this;
363+
}
364+
ScopedRedefine(ScopedRedefine&& other) { swap(other); }
365+
366+
void swap(ScopedRedefine& other) {
367+
std::swap(parent, other.parent);
368+
std::swap(old_var, other.old_var);
369+
std::swap(new_var, other.new_var);
370+
}
371+
372+
IRConvertSSA* parent{nullptr};
252373
Var old_var;
253374
Var new_var;
254375
};
@@ -447,5 +568,30 @@ std::pair<PrimExpr, PrimExpr> GetAsyncWaitAttributes(const AttrStmtNode* op) {
447568
return std::make_pair(op->value, inner->value);
448569
}
449570

571+
namespace transform {
572+
// Module-level pass that runs IRConvertSSA over every PrimFunc in the
// module.  Functions of any other kind are passed through untouched.
Pass ConvertSSA() {
  auto pass_func = [](IRModule mod, PassContext ctx) {
    // A single converter is shared by all functions, so what it
    // records while visiting one function is seen by the next.
    tir::IRConvertSSA ssa_converter;
    Map<GlobalVar, BaseFunc> rewritten;
    bool any_rewritten = false;

    for (auto [gvar, base_func] : mod->functions) {
      const auto* prim_func_node = base_func.as<tir::PrimFuncNode>();
      if (prim_func_node != nullptr) {
        auto converted = ssa_converter.VisitPrimFunc(GetRef<tir::PrimFunc>(prim_func_node));
        if (!converted.same_as(base_func)) {
          any_rewritten = true;
          rewritten.Set(gvar, converted);
          continue;
        }
      }
      rewritten.Set(gvar, base_func);
    }

    // Leave the module as it was unless at least one function changed.
    if (!any_rewritten) {
      return mod;
    }
    mod.CopyOnWrite()->functions = std::move(rewritten);
    return mod;
  };
  return tvm::transform::CreateModulePass(pass_func, 0, "tir.ConvertSSA", {});
}
594+
595+
} // namespace transform
450596
} // namespace tir
451597
} // namespace tvm

src/tir/transforms/split_host_device.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,7 @@ Pass SplitHostDevice() {
282282
}
283283
}
284284
mod->Update(device_mod);
285-
return mod;
285+
return ConvertSSA()(mod);
286286
};
287287

288288
return tvm::transform::CreateModulePass(pass_func, 0, "tir.SplitHostDevice", {});

tests/python/unittest/test_tir_transform_split_host_device.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import tvm
1818
from tvm import te
1919
import tvm.testing
20+
from tvm.script import tir as T, ir as I
2021

2122

2223
@tvm.testing.requires_cuda
@@ -48,5 +49,29 @@ def test_split_host_device_func_attr():
4849
assert fdevice.attrs["tir.is_global_func"].value
4950

5051

52+
def test_ssa_across_entire_module():
    """SplitHostDevice must not leave a TIR var shared by host and device

    A value handed from the host function to the device kernel has to
    be expressed through distinct TIR variables on each side.
    """

    @I.ir_module
    class before:
        @T.prim_func
        def main():
            T.func_attr({"global_symbol": "main", "target": T.target("cuda")})
            for i in range(16):
                T.attr(0, "device_scope", 0)
                for j in range(16):
                    T.evaluate(i)

    split_mod = tvm.tir.transform.SplitHostDevice()(before)
    host_func = split_mod["main"]
    device_func = split_mod["main_kernel0"]

    # The host keeps the outer loop; the kernel receives its index as a parameter.
    host_loop_var = host_func.body.loop_var
    kernel_param = device_func.params[0]

    assert not host_loop_var.same_as(kernel_param)
74+
75+
5176
if __name__ == "__main__":
5277
test_split_host_device_func_attr()

0 commit comments

Comments
 (0)