Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 89 additions & 23 deletions src/relax/transform/legalize_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include <tvm/relax/op_attr_types.h>
#include <tvm/relax/struct_info.h>
#include <tvm/relax/transform.h>
#include <tvm/tir/transform.h>

namespace tvm {
namespace relax {
Expand Down Expand Up @@ -74,16 +75,22 @@ class LegalizeMutator : public ExprMutator {
builder_->UpdateFunction(gv, Downcast<BaseFunc>(updated_func));
}
}
// Fill the "kTarget" attribute of PrimFunc
const auto& mod = builder_->GetContextIRModule();
for (const auto& gv : mod->GetGlobalVars()) {
const tir::PrimFuncNode* prim_func;
if (tmap_.count(gv) && (prim_func = mod->Lookup(gv).as<tir::PrimFuncNode>())) {
auto f = WithAttr(GetRef<tir::PrimFunc>(prim_func), tvm::attr::kTarget, tmap_[gv]);
builder_->UpdateFunction(gv, f);
}

IRModule output = builder_->GetContextIRModule();
if (generated_tir_with_target_attr_) {
// It is possible that every call to a legalized PrimFunc
// contains VDevice annotations. In that case, the PrimFunc
// without a target annotation no longer has any callers, and
// should be removed.
output = relax::transform::DeadCodeElimination()(output);

// Avoid accidental sharing of TIR variables in the legalized
// PrimFuncs, when kernels for multiple devices are generated
// from the same PrimFunc.
output = tir::transform::ConvertSSA()(output);
}
return builder_->GetContextIRModule();

return output;
}

private:
Expand Down Expand Up @@ -129,7 +136,7 @@ class LegalizeMutator : public ExprMutator {
return Call(call_pure_packed_op, ret_args, ret->attrs, ret->sinfo_args);
}

Target GetTarget(const Array<StructInfo>& sinfos) {
Optional<Target> GetTarget(const Array<StructInfo>& sinfos) {
for (auto sinfo : sinfos) {
if (const auto* tinfo = sinfo.as<TensorStructInfoNode>()) {
if (tinfo->vdevice.defined()) {
Expand All @@ -142,18 +149,76 @@ class LegalizeMutator : public ExprMutator {
return GetTarget(tup_sinfo->fields);
}
}
return Target();
return NullOpt;
}

void SaveTarget(const Expr& expr) {
if (expr->IsInstance<CallNode>()) {
auto call = Downcast<Call>(expr);
auto target = GetTarget(call->sinfo_args);
const GlobalVarNode* gvar_node;
if (target.defined() && (gvar_node = call->args[0].as<GlobalVarNode>())) {
this->tmap_.Set(GetRef<GlobalVar>(gvar_node), target);
}
Expr BindTarget(Expr expr) {
if (!expr->IsInstance<CallNode>()) {
// FLegalize returned something other than a relax::Call. This
// post-processing only handles cases where legalization
// produces a lowered call node. In principle, this
// post-processing isn't necessary, and FLegalize should already
// have generated vdevice-aware kernels, so hopefully the
// FLegalize implementation did so.
return expr;
}

auto call = Downcast<Call>(expr);

auto vdevice_target = GetTarget(call->sinfo_args);
if (!vdevice_target.defined()) {
// No vdevice annotation is present, so we don't need to apply
// any updates.
return expr;
}

if (call->args.empty()) {
return expr;
}

auto gvar = call->args[0].as<GlobalVar>();
if (!gvar.defined()) {
// This is not a call into a legalized function within the
// current IRModule, so no post-processing is required.
return expr;
}

auto base_func = builder_->GetContextIRModule()->Lookup(gvar.value());
auto opt_prim_func = base_func.as<tir::PrimFunc>();
if (!opt_prim_func) {
// The call is to something other than a PrimFunc. It may be
// another Relax function, in which case the legalization of its
// body will handle any additional target annotations.
return expr;
}
auto prim_func = opt_prim_func.value();

auto func_target = prim_func->GetAttr<Target>(tvm::attr::kTarget);
if (func_target && func_target.value()->kind == vdevice_target.value()->kind) {
// The function already has compatible annotations for the
// target, so no modifications are required.
return expr;
}

// The FLegalize function generated a PrimFunc, but that PrimFunc
// doesn't have annotations compatible with the vdevice required
// by the Relax StructInfo. Update the call to instead call a
// `PrimFunc` with the appropriate target annotation. In the
// future, this may be treated as a bug in the FLegalize
// implementation, rather than expected output from it.
auto new_prim_func = WithAttr(prim_func, tvm::attr::kTarget, vdevice_target.value());
auto new_gvar_name = [&]() -> std::string {
std::stringstream ss;
ss << gvar.value()->name_hint;
ss << "_";
ss << vdevice_target.value()->kind->name;
return ss.str();
}();
auto new_gvar = builder_->AddFunction(new_prim_func, new_gvar_name);
generated_tir_with_target_attr_ = true;

call.CopyOnWrite()->args.Set(0, new_gvar);
return call;
}

Expr VisitExpr_(const CallNode* call) final {
Expand Down Expand Up @@ -268,8 +333,9 @@ class LegalizeMutator : public ExprMutator {
}
Expr legalized = legalization_func(builder_, visited_call);

// Save the expected target info. into tmap_
SaveTarget(legalized);
// Append the target attribute to any PrimFunc generated in
// legalization.
legalized = BindTarget(legalized);

legalized = builder_->Normalize(legalized);

Expand Down Expand Up @@ -303,8 +369,8 @@ class LegalizeMutator : public ExprMutator {
IRModule mod_;
/*! \brief The customized legalization function map. */
Map<String, PackedFunc> cmap_;
/*! \brief The map from GlobalVar of PrimFunc to compilation Target. */
Map<GlobalVar, Target> tmap_;
/*! \brief If VDevice annotations produced at least one PrimFunc with a Target attr*/
bool generated_tir_with_target_attr_{false};
/*!
* \brief A boolean value indicating if to print warnings for CallNode whose op's
* legalization function is not registered.
Expand Down
36 changes: 36 additions & 0 deletions src/tir/transforms/ir_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,42 @@ class IRConvertSSA final : public StmtExprMutator {
return std::move(decl);
}

Stmt VisitStmt_(const BlockNode* op) final {
Block block = GetRef<Block>(op);

// The BlockNode is the point of definition for the IterVar
// instances. These re-defines must be present before visiting
// the body of the BlockNode.
std::vector<ScopedRedefine> redefines;
Array<IterVar> iter_vars = op->iter_vars.Map([&](IterVar iter_var) {
if (defined_.count(iter_var->var.get())) {
redefines.emplace_back(this, iter_var->var);
iter_var.CopyOnWrite()->var = redefines.back().new_var;
} else {
defined_.insert(iter_var->var.get());
}
return iter_var;
});
Array<BufferRegion> reads =
block->reads.Map([&](const auto& region) { return VisitBufferAccess(region); });
Array<BufferRegion> writes =
block->writes.Map([&](const auto& region) { return VisitBufferAccess(region); });

if (!reads.same_as(block->reads) || !writes.same_as(block->writes) ||
!iter_vars.same_as(op->iter_vars)) {
auto write_ptr = block.CopyOnWrite();
write_ptr->reads = reads;
write_ptr->writes = writes;
write_ptr->iter_vars = iter_vars;
}

Stmt output = Downcast<Block>(StmtExprMutator::VisitStmt_(block.get()));

while (redefines.size()) redefines.pop_back();

return output;
}

template <typename Node>
Node VisitBufferAccess(Node node) {
Buffer new_buf = GetRemappedBuffer(node->buffer);
Expand Down
81 changes: 81 additions & 0 deletions tests/python/relax/test_transform_legalize_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,5 +356,86 @@ def main(
tvm.ir.assert_structural_equal(AfterFirstIter, AfterSecondIter)


def test_legalize_with_vdevice():
"""Legalization may generate kernels for multiple targets

This is a regression test. In previous implementations, Relax
expressions whose argument types differed only by their `vdevice`
would be legalized to use the same `PrimFunc`.

"""

@I.ir_module
class Before:
I.module_global_infos({"vdevice": [I.vdevice("llvm")]})

@R.function
def func_cuda(A: R.Tensor([32, 32], "float32"), B: R.Tensor([32, 32], "float32")):
C = R.add(A, B)
return C

@R.function
def func_llvm(
A: R.Tensor([32, 32], "float32", "llvm"), B: R.Tensor([32, 32], "float32", "llvm")
):
C = R.add(A, B)
return C

@I.ir_module
class Expected:
I.module_global_infos({"vdevice": [I.vdevice("llvm")]})

@R.function
def func_cuda(
A: R.Tensor((32, 32), dtype="float32"),
B: R.Tensor((32, 32), dtype="float32"),
):
cls = Expected
C = R.call_tir(cls.add, (A, B), out_sinfo=R.Tensor((32, 32), dtype="float32"))
return C

@T.prim_func(private=True)
def add(
A: T.Buffer((T.int64(32), T.int64(32)), "float32"),
B: T.Buffer((T.int64(32), T.int64(32)), "float32"),
C: T.Buffer((T.int64(32), T.int64(32)), "float32"),
):
T.func_attr({"tir.noalias": T.bool(True)})
for iters in T.grid(T.int64(32), T.int64(32)):
with T.block("T_add"):
ax0, ax1 = T.axis.remap("SS", iters)
C[ax0, ax1] = A[ax0, ax1] + B[ax0, ax1]

@R.function
def func_llvm(
A: R.Tensor((32, 32), dtype="float32", vdevice="llvm"),
B: R.Tensor((32, 32), dtype="float32", vdevice="llvm"),
):
cls = Expected
C = R.call_tir(
cls.add_llvm,
(A, B),
out_sinfo=R.Tensor((32, 32), dtype="float32", vdevice="llvm"),
)
return C

@T.prim_func(private=True)
def add_llvm(
A: T.Buffer((T.int64(32), T.int64(32)), "float32"),
B: T.Buffer((T.int64(32), T.int64(32)), "float32"),
C: T.Buffer((T.int64(32), T.int64(32)), "float32"),
):
T.func_attr({"target": T.target("llvm"), "tir.noalias": T.bool(True)})
for iters in T.grid(T.int64(32), T.int64(32)):
with T.block("T_add"):
ax0, ax1 = T.axis.remap("SS", iters)
C[ax0, ax1] = A[ax0, ax1] + B[ax0, ax1]

with tvm.target.Target("cuda"):
After = tvm.relax.transform.LegalizeOps()(Before)

tvm.ir.assert_structural_equal(Expected, After)


if __name__ == "__main__":
tvm.testing.main()