Skip to content

Commit 2d7e065

Browse files
committed
[Relax][Transform] Handle identical PrimFunc with distinct VDevice
Prior to this commit, if an `IRModule` contained two expressions, where the types of the arguments differed only by the `VDevice`, these would be legalized to produce a single PrimFunc. This PrimFunc would have a `tvm::attr::kTarget` annotation specific to one of those expressions, and would be incorrect for use in the other location. This commit updates the `LegalizeOps` transform to handle this case, producing multiple TIR PrimFuncs if required by the `VDevice` annotations.
1 parent 6252fa5 commit 2d7e065

File tree

3 files changed

+236
-8
lines changed

3 files changed

+236
-8
lines changed

src/relax/transform/legalize_ops.cc

Lines changed: 87 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include <tvm/relax/op_attr_types.h>
2929
#include <tvm/relax/struct_info.h>
3030
#include <tvm/relax/transform.h>
31+
#include <tvm/tir/transform.h>
3132

3233
namespace tvm {
3334
namespace relax {
@@ -83,7 +84,12 @@ class LegalizeMutator : public ExprMutator {
8384
builder_->UpdateFunction(gv, f);
8485
}
8586
}
86-
return builder_->GetContextIRModule();
87+
IRModule output = builder_->GetContextIRModule();
88+
if (requires_tir_convert_ssa_) {
89+
output = tir::transform::ConvertSSA()(output);
90+
}
91+
92+
return output;
8793
}
8894

8995
private:
@@ -129,7 +135,7 @@ class LegalizeMutator : public ExprMutator {
129135
return Call(call_pure_packed_op, ret_args, ret->attrs, ret->sinfo_args);
130136
}
131137

132-
Target GetTarget(const Array<StructInfo>& sinfos) {
138+
Optional<Target> GetTarget(const Array<StructInfo>& sinfos) {
133139
for (auto sinfo : sinfos) {
134140
if (const auto* tinfo = sinfo.as<TensorStructInfoNode>()) {
135141
if (tinfo->vdevice.defined()) {
@@ -142,20 +148,90 @@ class LegalizeMutator : public ExprMutator {
142148
return GetTarget(tup_sinfo->fields);
143149
}
144150
}
145-
return Target();
151+
return NullOpt;
146152
}
147153

148154
void SaveTarget(const Expr& expr) {
149155
if (expr->IsInstance<CallNode>()) {
150156
auto call = Downcast<Call>(expr);
151-
auto target = GetTarget(call->sinfo_args);
152-
const GlobalVarNode* gvar_node;
153-
if (target.defined() && (gvar_node = call->args[0].as<GlobalVarNode>())) {
154-
this->tmap_.Set(GetRef<GlobalVar>(gvar_node), target);
157+
158+
if (auto target = GetTarget(call->sinfo_args)) {
159+
if (auto gvar = call->args[0].as<GlobalVar>()) {
160+
this->tmap_.Set(gvar.value(), target.value());
161+
}
155162
}
156163
}
157164
}
158165

166+
Expr BindTarget(Expr expr) {
167+
if (!expr->IsInstance<CallNode>()) {
168+
// FLegalize returned something other than a relax::Call. This
169+
// post-processing only handles cases where legalization
170+
// produces a lowered call node. In principle, this
171+
// post-processing isn't necessary, and FLegalize should already
172+
// have generated vdevice-aware kernels, so hopefully the
173+
// FLegalize implementation did so.
174+
return expr;
175+
}
176+
177+
auto call = Downcast<Call>(expr);
178+
179+
auto vdevice_target = GetTarget(call->sinfo_args);
180+
if (!vdevice_target.defined()) {
181+
// No vdevice annotation is present, so we don't need to apply
182+
// any updates.
183+
return expr;
184+
}
185+
186+
if (call->args.empty()) {
187+
return expr;
188+
}
189+
190+
auto gvar = call->args[0].as<GlobalVar>();
191+
if (!gvar.defined()) {
192+
// This is not a call into a legalized function within the
193+
// current IRModule, so no post-processing is required.
194+
return expr;
195+
}
196+
197+
auto base_func = builder_->GetContextIRModule()->Lookup(gvar.value());
198+
auto opt_prim_func = base_func.as<tir::PrimFunc>();
199+
if (!opt_prim_func) {
200+
// The call is to something other than a PrimFunc. It may be
201+
// another Relax function, in which case the legalization of its
202+
// body will handle any additional target annotations.
203+
return expr;
204+
}
205+
auto prim_func = opt_prim_func.value();
206+
207+
auto func_target = prim_func->GetAttr<Target>(tvm::attr::kTarget);
208+
if (func_target && func_target.value()->kind == vdevice_target.value()->kind) {
209+
// The function already has compatible annotations for the
210+
// target, so no modifications are required.
211+
return expr;
212+
}
213+
214+
// The FLegalize function generated a PrimFunc, but that PrimFunc
215+
// doesn't have annotations compatible with the vdevice required
216+
// by the Relax StructInfo. Update the call to instead call a
217+
// `PrimFunc` with the appropriate target annotation. In the
218+
// future, this may be treated as a bug in the FLegalize
219+
// implementation, rather than expected output from it.
220+
auto new_prim_func = WithAttr(prim_func, tvm::attr::kTarget, vdevice_target.value());
221+
auto new_gvar_name = [&]() -> std::string {
222+
std::stringstream ss;
223+
ss << gvar.value()->name_hint;
224+
ss << "_";
225+
ss << vdevice_target.value()->kind->name;
226+
return ss.str();
227+
}();
228+
auto new_gvar = builder_->AddFunction(new_prim_func, new_gvar_name);
229+
requires_tir_convert_ssa_ = true;
230+
231+
call.CopyOnWrite()->args.Set(0, new_gvar);
232+
return call;
233+
}
234+
159235
Expr VisitExpr_(const CallNode* call) final {
160236
Call visited_call = Downcast<Call>(this->VisitExprPostOrder_(call));
161237
static const auto& legalize_map = Op::GetAttrMap<FLegalize>("FLegalize");
@@ -268,8 +344,10 @@ class LegalizeMutator : public ExprMutator {
268344
}
269345
Expr legalized = legalization_func(builder_, visited_call);
270346

347+
legalized = BindTarget(legalized);
348+
271349
// Save the expected target info. into tmap_
272-
SaveTarget(legalized);
350+
// SaveTarget(legalized);
273351

274352
legalized = builder_->Normalize(legalized);
275353

@@ -305,6 +383,7 @@ class LegalizeMutator : public ExprMutator {
305383
Map<String, PackedFunc> cmap_;
306384
/*! \brief The map from GlobalVar of PrimFunc to compilation Target. */
307385
Map<GlobalVar, Target> tmap_;
386+
bool requires_tir_convert_ssa_{false};
308387
/*!
309388
* \brief A boolean value indicating if to print warnings for CallNode whose op's
310389
* legalization function is not registered.

src/tir/transforms/ir_utils.cc

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,42 @@ class IRConvertSSA final : public StmtExprMutator {
246246
return std::move(decl);
247247
}
248248

249+
Stmt VisitStmt_(const BlockNode* op) final {
250+
Block block = GetRef<Block>(op);
251+
252+
// The BlockNode is the point of definition for the IterVar
253+
// instances. These re-defines must be present before visiting
254+
// the body of the BlockNode.
255+
std::vector<ScopedRedefine> redefines;
256+
Array<IterVar> iter_vars = op->iter_vars.Map([&](IterVar iter_var) {
257+
if (defined_.count(iter_var->var.get())) {
258+
redefines.emplace_back(this, iter_var->var);
259+
iter_var.CopyOnWrite()->var = redefines.back().new_var;
260+
} else {
261+
defined_.insert(iter_var->var.get());
262+
}
263+
return iter_var;
264+
});
265+
Array<BufferRegion> reads =
266+
block->reads.Map([&](const auto& region) { return VisitBufferAccess(region); });
267+
Array<BufferRegion> writes =
268+
block->writes.Map([&](const auto& region) { return VisitBufferAccess(region); });
269+
270+
if (!reads.same_as(block->reads) || !writes.same_as(block->writes) ||
271+
!iter_vars.same_as(op->iter_vars)) {
272+
auto write_ptr = block.CopyOnWrite();
273+
write_ptr->reads = reads;
274+
write_ptr->writes = writes;
275+
write_ptr->iter_vars = iter_vars;
276+
}
277+
278+
Stmt output = Downcast<Block>(StmtExprMutator::VisitStmt_(block.get()));
279+
280+
while (redefines.size()) redefines.pop_back();
281+
282+
return output;
283+
}
284+
249285
template <typename Node>
250286
Node VisitBufferAccess(Node node) {
251287
Buffer new_buf = GetRemappedBuffer(node->buffer);

tests/python/relax/test_transform_legalize_ops.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -356,5 +356,118 @@ def main(
356356
tvm.ir.assert_structural_equal(AfterFirstIter, AfterSecondIter)
357357

358358

359+
def test_legalize_with_vdevice():
360+
"""Legalization may generate kernels for multiple targets
361+
362+
This is a regression test. In previous implementations, Relax
363+
expressions whose argument types differed only by their `vdevice`
364+
would be legalized to use the same `PrimFunc`.
365+
366+
"""
367+
368+
@I.ir_module
369+
class Before:
370+
I.module_global_infos({"vdevice": [I.vdevice("llvm")]})
371+
372+
@R.function
373+
def func_cuda(A: R.Tensor([32, 32], "float32"), B: R.Tensor([32, 32], "float32")):
374+
C = R.add(A, B)
375+
return C
376+
377+
@R.function
378+
def func_llvm(
379+
A: R.Tensor([32, 32], "float32", "llvm"), B: R.Tensor([32, 32], "float32", "llvm")
380+
):
381+
C = R.add(A, B)
382+
return C
383+
384+
@I.ir_module
385+
class Expected:
386+
I.module_global_infos(
387+
{
388+
"vdevice": [
389+
I.vdevice(
390+
{
391+
"keys": ["cpu"],
392+
"kind": "llvm",
393+
"mtriple": "x86_64-pc-linux-gnu",
394+
"tag": "",
395+
},
396+
0,
397+
"global",
398+
)
399+
]
400+
}
401+
)
402+
403+
@T.prim_func(private=True)
404+
def add(
405+
A: T.Buffer((T.int64(32), T.int64(32)), "float32"),
406+
B: T.Buffer((T.int64(32), T.int64(32)), "float32"),
407+
T_add: T.Buffer((T.int64(32), T.int64(32)), "float32"),
408+
):
409+
T.func_attr({"tir.noalias": T.bool(True)})
410+
# with T.block("root"):
411+
for ax0, ax1 in T.grid(T.int64(32), T.int64(32)):
412+
with T.block("T_add"):
413+
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
414+
T.reads(A[v_ax0, v_ax1], B[v_ax0, v_ax1])
415+
T.writes(T_add[v_ax0, v_ax1])
416+
T_add[v_ax0, v_ax1] = A[v_ax0, v_ax1] + B[v_ax0, v_ax1]
417+
418+
@T.prim_func(private=True)
419+
def add_llvm(
420+
A: T.Buffer((T.int64(32), T.int64(32)), "float32"),
421+
B: T.Buffer((T.int64(32), T.int64(32)), "float32"),
422+
T_add: T.Buffer((T.int64(32), T.int64(32)), "float32"),
423+
):
424+
T.func_attr(
425+
{
426+
"target": T.target(
427+
{
428+
"keys": ["cpu"],
429+
"kind": "llvm",
430+
"mtriple": "x86_64-pc-linux-gnu",
431+
"tag": "",
432+
}
433+
),
434+
"tir.noalias": T.bool(True),
435+
}
436+
)
437+
# with T.block("root"):
438+
for ax0, ax1 in T.grid(T.int64(32), T.int64(32)):
439+
with T.block("T_add"):
440+
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
441+
T.reads(A[v_ax0, v_ax1], B[v_ax0, v_ax1])
442+
T.writes(T_add[v_ax0, v_ax1])
443+
T_add[v_ax0, v_ax1] = A[v_ax0, v_ax1] + B[v_ax0, v_ax1]
444+
445+
@R.function
446+
def func_cuda(
447+
A: R.Tensor((32, 32), dtype="float32"), B: R.Tensor((32, 32), dtype="float32")
448+
) -> R.Tensor((32, 32), dtype="float32"):
449+
cls = Expected
450+
C = R.call_tir(cls.add, (A, B), out_sinfo=R.Tensor((32, 32), dtype="float32"))
451+
return C
452+
453+
@R.function
454+
def func_llvm(
455+
A: R.Tensor((32, 32), dtype="float32", vdevice="llvm:0"),
456+
B: R.Tensor((32, 32), dtype="float32", vdevice="llvm:0"),
457+
) -> R.Tensor((32, 32), dtype="float32", vdevice="llvm:0"):
458+
cls = Expected
459+
C = R.call_tir(
460+
cls.add_llvm,
461+
(A, B),
462+
out_sinfo=R.Tensor((32, 32), dtype="float32", vdevice="llvm:0"),
463+
)
464+
return C
465+
466+
with tvm.target.Target("cuda"):
467+
After = tvm.relax.transform.LegalizeOps()(Before)
468+
469+
tvm.ir.assert_structural_equal(Expected, After)
470+
471+
359472
if __name__ == "__main__":
360473
tvm.testing.main()

0 commit comments

Comments (0)