Skip to content

Commit c65b2c8

Browse files
committed
[Relay] Plumb external codegen target via Target.current() for all external codegen paths
(See https://discuss.tvm.apache.org/t/byoc-supporting-cutlass-byoc-with-collage/12796/6 for context, which in turn is part of Collage (https://github.com/apache/tvm-rfcs/blob/main/rfcs/0062-collage.md).) We want both old-style (via relay.ext.$toolchain) and new-style (via "RelayToTIR" Pass attribute on target kind) external codegen to be able to access the current 'external codegen' Target instance via Target.current(). - For old-style, plumb the true Target through TECompiler and push it on the context stack before calling relay.ext.$toolchain. - For new-style, pass the CompilationConfig to the RelayToTIRTargetHook pass, make the jump from "Compiler" attribute value to Target via the new CompilationConfig::FindPrimitiveTargetForKind method, and push on the stack before invoking the custom "RelayToTIR" pass. While working on this I discovered that RelayToTIRTargetHook was incompatible with the VM's compilation flow, since RelayToTIRTargetHook assumes all "Compiler" attributed functions are inlined. Generalize it to support both inline and global function styles. Extend Target::IsExternalCodegen to recognize target kinds with "RelayToTIR" attributes as external. Update target hooks unit test to exercise new support for outline-style, picking up the current target, and compiling via the VM.
1 parent d146777 commit c65b2c8

File tree

20 files changed

+409
-155
lines changed

20 files changed

+409
-155
lines changed

include/tvm/relay/transform.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -464,9 +464,11 @@ TVM_DLL Pass SimplifyExpr();
464464
/*!
465465
* \brief Run any RelayToTIR passes registered on the functions in a module.
466466
*
467+
* \param config All available targets.
468+
*
467469
* \return The pass.
468470
*/
469-
TVM_DLL Pass RelayToTIRTargetHook();
471+
TVM_DLL Pass RelayToTIRTargetHook(CompilationConfig config);
470472

471473
/*!
472474
* \brief A pass for manifesting explicit memory allocations and rewriting

include/tvm/target/target_kind.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,16 @@ namespace attr {
402402
* See also \p Target::IsExternalCodegenFor
403403
*/
404404
constexpr const char* kIsExternalCodegen = "is_external_codegen";
405+
406+
/*!
407+
* \brief A \p TargetKind attribute of type \p FTVMRelayToTIR. If set, then the target kind name
408+
* also corresponds to an external codegen 'compiler' name, and the bound value is a \p Pass
409+
* to apply before the TVM lowering.
410+
*
411+
* See also \p Target::IsExternalCodegenFor
412+
*/
413+
constexpr const char* kRelayToTIR = "RelayToTIR";
414+
405415
} // namespace attr
406416

407417
/*!

src/relay/backend/aot_executor_codegen.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1079,7 +1079,7 @@ class AOTExecutorCodegen : public MixedModeVisitor {
10791079
// lowering process directly.
10801080
tec::UpdateFunctionMetadata(func, this->function_metadata_, workspace_byte_alignment);
10811081
},
1082-
config_->host_virtual_device)(mod);
1082+
config_)(mod);
10831083

10841084
auto lowered_main = lowered_mod->Lookup("main");
10851085
auto lowered_main_func = GetRef<Function>(lowered_main.as<FunctionNode>());

src/relay/backend/contrib/cmsisnn/target.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ tvm::transform::Pass RelayToTIR();
3131
runtime::Module TIRToRuntime(IRModule mod, Target target);
3232

3333
TVM_REGISTER_TARGET_KIND("cmsis-nn", kDLCPU)
34-
.set_attr<FTVMRelayToTIR>("RelayToTIR", RelayToTIR())
34+
.set_attr<FTVMRelayToTIR>(tvm::attr::kRelayToTIR, RelayToTIR())
3535
.set_attr<FTVMTIRToRuntime>("TIRToRuntime", TIRToRuntime);
3636

3737
} // namespace cmsisnn

src/relay/backend/contrib/ethosu/codegen.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,7 @@ runtime::Module TIRToRuntime(IRModule mod, Target target) {
320320

321321
TVM_REGISTER_TARGET_KIND("ethos-u", kDLCPU)
322322
.set_attr<Bool>("use_device_api", Bool(true))
323-
.set_attr<FTVMRelayToTIR>("RelayToTIR", RelayToTIR())
323+
.set_attr<FTVMRelayToTIR>(tvm::attr::kRelayToTIR, RelayToTIR())
324324
.set_attr<FTVMTIRToRuntime>("TIRToRuntime", TIRToRuntime);
325325

326326
} // namespace ethosu

src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc

Lines changed: 145 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,35 @@
2828
#include <tvm/tir/op.h>
2929

3030
#include "../../../op/call/call.h"
31+
#include "tvm/tir/function.h"
3132

3233
namespace tvm {
3334
namespace relay {
3435
namespace contrib {
3536
namespace example_target_hooks {
3637

38+
namespace {
39+
40+
/*!
41+
* \brief An example mutator for a "RelayToTIR" custom pass. Replaces every call to a Relay
42+
* Function with "external_symbol" attribute of "replace_add_with_subtract" with a call to a
43+
* TIR PrimFunc implementing subtraction.
44+
*
45+
* Illustrates six aspects a custom 'lowering' style pass may need to account for:
46+
* - Lowerable functions can appear inline as call ops, bound to let-bound variables, or as
47+
* global functions.
48+
* - Let-bound lowerable functions should be inlined on-the-fly since after processing the
49+
* let-binding is no longer required.
50+
* - There may be multiple calls to the same lowerable function. All calls need to be
51+
* rewritten, even though the function itself need be rewritten only once.
52+
* - GlobalVars must be shared between all calls and the new definition itself.
53+
* - Calls to lowered functions must use the "call_lowered" calling convention.
54+
* - The Target::Current() may hold an instance of the TargetKind from which the custom Pass
55+
* was extracted.
56+
*
57+
* Though not illustrated here, it is also valid for a "RelayToTIR" custom pass to add
58+
* runtime::Modules to the output IRModule's "external_mods" attribute.
59+
*/
3760
class ConvertAddToSubtract : public MixedModeMutator {
3861
public:
3962
explicit ConvertAddToSubtract(IRModule ir_module, Target host_target)
@@ -56,51 +79,105 @@ class ConvertAddToSubtract : public MixedModeMutator {
5679
return tir::BufferLoad(buffer, {index});
5780
}
5881

59-
void ReplaceAddWithSubtractPrimFunc(const GlobalVar& new_global_var, const Function& func) {
60-
tir::Buffer x_buffer = tir::decl_buffer({8}, DataType::Float(32), "x");
61-
tir::Buffer y_buffer = tir::decl_buffer({8}, DataType::Float(32), "y");
62-
tir::Buffer out_buffer = tir::decl_buffer({8}, DataType::Float(32));
82+
GlobalVar ReplaceAddWithSubtractPrimFunc(const Function& func) {
83+
auto func_name = func->GetAttr<String>(::tvm::attr::kGlobalSymbol);
84+
ICHECK(func_name.defined());
85+
86+
// --------------------------------------------------------------------------------------------
87+
// Cases:
88+
// - Inline function:
89+
// - First encounter: create global var, rewrite to PrimFunc, add binding, replace call.
90+
// - Thereafter (via object sharing): discover global var already in module, replace call
91+
// - Global function:
92+
// - func_name == global_var->name_hint
93+
// - First encounter: rewrite to PrimFunc and update binding, replace call
94+
// - Thereafter (via global var): Just replace call
95+
// - func_name != global_var->name_hint
96+
// - First encounter: create global var, rewrite to PrimFunc, add binding, replace call
97+
// (The original Relay function should also be tagged as 'extern', i.e. given attribute
98+
// "ExternalSymbol".)
99+
// - Thereafter (via global var): discover global var already in module, replace call
100+
// --------------------------------------------------------------------------------------------
101+
102+
// If necessary, introduce a new global var to map the function to and copy the source type
103+
// over for InferType.
104+
GlobalVar global_var;
105+
bool need_rewriting;
106+
if (ir_module_->ContainGlobalVar(func_name.value())) {
107+
global_var = ir_module_->GetGlobalVar(func_name.value());
108+
// Only rewrite to a PrimFunc if the global definition is still a Relay function.
109+
need_rewriting = ir_module_->Lookup(global_var)->IsInstance<FunctionNode>();
110+
} else {
111+
global_var = GlobalVar(func_name.value());
112+
global_var->checked_type_ = func->checked_type();
113+
need_rewriting = true;
114+
}
115+
116+
// For illustration only, check if the current target matches the example_target_hook kind,
117+
// and if so extract the example attribute value.
118+
int64_t example_attribute_value = 0;
119+
Optional<Target> opt_current_target = Target::Current();
120+
if (opt_current_target.defined() &&
121+
opt_current_target.value()->kind->name == "example_target_hook") {
122+
example_attribute_value =
123+
opt_current_target.value()->GetAttr<Integer>("example_attribute").value()->value;
124+
}
63125

64-
tir::Var x_var("x", DataType::Handle());
65-
tir::Var y_var("y", DataType::Handle());
66-
tir::Var out_var("out", DataType::Handle());
126+
if (need_rewriting) {
127+
// The called function is still in Relay form. Convert to TIR.
128+
tir::Buffer x_buffer = tir::decl_buffer({8}, DataType::Float(32), "x");
129+
tir::Buffer y_buffer = tir::decl_buffer({8}, DataType::Float(32), "y");
130+
tir::Buffer out_buffer = tir::decl_buffer({8}, DataType::Float(32));
67131

68-
Map<String, ObjectRef> dict_attrs;
69-
dict_attrs.Set("global_symbol", new_global_var->name_hint);
70-
dict_attrs.Set("tir.noalias", Bool(true));
132+
tir::Var x_var("x", DataType::Handle());
133+
tir::Var y_var("y", DataType::Handle());
134+
tir::Var out_var("out", DataType::Handle());
71135

72-
te::Var index("index", DataType::Int(32));
73-
tir::Sub indexed_sub = tir::Sub(LoadIndex(x_buffer, index), LoadIndex(y_buffer, index));
74-
tir::Stmt math_body = tir::BufferStore(out_buffer, indexed_sub, {index});
75-
tir::Stmt math_loop = tir::For(index, 0, 8, tir::ForKind::kSerial, math_body);
136+
Map<String, ObjectRef> dict_attrs;
137+
dict_attrs.Set("global_symbol", global_var->name_hint);
138+
dict_attrs.Set("tir.noalias", Bool(true));
76139

77-
Map<tir::Var, tir::Buffer> buffer_map = {
78-
{x_var, x_buffer},
79-
{y_var, y_buffer},
80-
{out_var, out_buffer},
81-
};
140+
te::Var index("index", DataType::Int(32));
141+
tir::Sub indexed_sub = tir::Sub(LoadIndex(x_buffer, index), LoadIndex(y_buffer, index));
142+
if (example_attribute_value > 0) {
143+
// For illustration only, fold the example attribute into the result.
144+
indexed_sub = tir::Sub(indexed_sub, FloatImm(DataType::Float(32),
145+
static_cast<double>(example_attribute_value)));
146+
}
82147

83-
tir::PrimFunc replacement_func = tir::PrimFunc({x_var, y_var, out_var}, math_loop, VoidType(),
84-
buffer_map, {}, DictAttrs(dict_attrs));
148+
tir::Stmt math_body = tir::BufferStore(out_buffer, indexed_sub, {index});
149+
tir::Stmt math_loop = tir::For(index, 0, 8, tir::ForKind::kSerial, math_body);
85150

86-
// Switch to TIRToRuntime hook for testing
87-
Bool tir_to_runtime = func->GetAttr<Bool>("tir_to_runtime").value_or(Bool(false));
88-
if (tir_to_runtime) {
89-
replacement_func = WithAttr(replacement_func, ::tvm::attr::kTarget, custom_target_);
90-
} else {
91-
replacement_func = WithAttr(replacement_func, ::tvm::attr::kTarget, host_target_);
151+
Map<tir::Var, tir::Buffer> buffer_map = {
152+
{x_var, x_buffer},
153+
{y_var, y_buffer},
154+
{out_var, out_buffer},
155+
};
156+
157+
tir::PrimFunc replacement_func = tir::PrimFunc({x_var, y_var, out_var}, math_loop, VoidType(),
158+
buffer_map, {}, DictAttrs(dict_attrs));
159+
160+
// Switch to TIRToRuntime hook for testing
161+
Bool tir_to_runtime = func->GetAttr<Bool>("tir_to_runtime").value_or(Bool(false));
162+
if (tir_to_runtime) {
163+
replacement_func = WithAttr(replacement_func, ::tvm::attr::kTarget, custom_target_);
164+
} else {
165+
replacement_func = WithAttr(replacement_func, ::tvm::attr::kTarget, host_target_);
166+
}
167+
168+
ir_module_->Update(global_var, replacement_func); // Will Add if global_var is new.
92169
}
93170

94-
ir_module_->Add(new_global_var, replacement_func);
171+
return global_var;
95172
}
96173

97174
Expr VisitExpr_(const LetNode* op) final {
98175
auto pre_visit = [this](const LetNode* op) {
99176
Expr var = this->VisitExpr(op->var);
100177
Expr value = this->VisitExpr(op->value);
101178

102-
// Outlineable function no longer needs let binding
103-
if (this->CanLowerExpr(value)) {
179+
if (AsLowerableFunction(value)) {
180+
// Inline on-the-fly if the let-bound value is lowerable.
104181
this->memo_[var] = value;
105182
}
106183
};
@@ -110,8 +187,8 @@ class ConvertAddToSubtract : public MixedModeMutator {
110187
Expr body = this->VisitExpr(op->body);
111188
auto expr = GetRef<Expr>(op);
112189

113-
// Drop the let binding
114-
if (this->CanLowerExpr(value)) {
190+
if (AsLowerableFunction(value)) {
191+
// The let binding is no longer needed since inlined on-the-fly above.
115192
this->memo_[expr] = this->VisitExpr(op->body);
116193
} else {
117194
Var var = Downcast<Var>(this->VisitExpr(op->var));
@@ -126,39 +203,49 @@ class ConvertAddToSubtract : public MixedModeMutator {
126203
return memo_[GetRef<Expr>(op)];
127204
}
128205

129-
bool CanLowerExpr(const Expr& expr) {
130-
const auto* func = expr.as<FunctionNode>();
131-
if (func == nullptr) {
132-
return false;
133-
}
134-
auto func_name = func->GetAttr<String>(::tvm::attr::kGlobalSymbol);
135-
if (!func_name.defined()) {
136-
return false;
206+
const FunctionNode* AsLowerableFunction(const Expr& expr) {
207+
if (const auto* function_node = expr.as<FunctionNode>()) {
208+
auto func_name = function_node->GetAttr<String>(::tvm::attr::kGlobalSymbol);
209+
if (!func_name.defined()) {
210+
return nullptr;
211+
}
212+
if (func_name != "replace_add_with_subtract") {
213+
return nullptr;
214+
}
215+
return function_node;
216+
} else if (const auto* global_var_node = expr.as<GlobalVarNode>()) {
217+
return AsLowerableFunction(ir_module_->Lookup(GetRef<GlobalVar>(global_var_node)));
218+
} else {
219+
return nullptr;
137220
}
138-
if (func_name != "replace_add_with_subtract") {
139-
return false;
221+
}
222+
223+
const GlobalVarNode* AsAlreadyLoweredFunction(const Expr& expr) {
224+
if (const auto* global_var_node = expr.as<GlobalVarNode>()) {
225+
if (ir_module_->Lookup(GetRef<GlobalVar>(global_var_node)).as<tir::PrimFuncNode>()) {
226+
return global_var_node;
227+
}
140228
}
141-
return true;
229+
return nullptr;
142230
}
143231

144232
Expr Rewrite_(const CallNode* pre, const Expr& post) override {
145-
if (const CallNode* call = post.as<CallNode>()) {
146-
if (CanLowerExpr(call->op)) {
147-
auto* func = call->op.as<FunctionNode>();
148-
auto func_name = func->GetAttr<String>(::tvm::attr::kGlobalSymbol);
149-
150-
// Introduce a new global var to map the function to and copy the source type
151-
// over for InferType
152-
GlobalVar new_global_var(func_name.value());
153-
new_global_var->checked_type_ = func->checked_type();
154-
ReplaceAddWithSubtractPrimFunc(new_global_var, GetRef<Function>(func));
155-
233+
if (const auto* call = post.as<CallNode>()) {
234+
GlobalVar new_op;
235+
if (const auto* function_node = AsLowerableFunction(call->op)) {
236+
// Add or replace the function with a PrimFunc.
237+
new_op = ReplaceAddWithSubtractPrimFunc(GetRef<Function>(function_node));
238+
} else if (const auto* global_var_node = AsAlreadyLoweredFunction(call->op)) {
239+
// The function has already been rewritten, so we just need to update the call.
240+
new_op = GetRef<GlobalVar>(global_var_node);
241+
}
242+
if (new_op.defined()) {
156243
// Since we are replacing the Relay function with a call to a TIR function, we must use
157244
// the call_lowered op.
158245
CallLoweredAttrs attrs;
159246
attrs.metadata.Set("relay_attrs", call->attrs);
160247
ICHECK(call->type_args.empty()) << "lowered functions cannot be polymorphic";
161-
return CallLowered(std::move(new_global_var), call->args, std::move(attrs), call->span);
248+
return CallLowered(std::move(new_op), call->args, std::move(attrs), call->span);
162249
}
163250
}
164251

@@ -171,10 +258,12 @@ class ConvertAddToSubtract : public MixedModeMutator {
171258
Target custom_target_;
172259
};
173260

261+
} // namespace
262+
174263
transform::Pass RelayToTIR() {
175264
runtime::TypedPackedFunc<IRModule(IRModule, transform::PassContext)> pass_func =
176265
[=](IRModule ir_module, transform::PassContext pass_context) {
177-
auto relay_to_tir = ConvertAddToSubtract(ir_module, Target("c"));
266+
ConvertAddToSubtract relay_to_tir(std::move(ir_module), Target("c"));
178267
return relay_to_tir.Mutate();
179268
};
180269
return tvm::transform::CreateModulePass(pass_func, 0, "RelayToTIR", {});

src/relay/backend/contrib/example_target_hooks/target.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ runtime::Module TIRToRuntime(IRModule mod, Target target);
3434

3535
TVM_REGISTER_TARGET_KIND("example_target_hook", kDLCPU)
3636
.set_attr<Bool>("use_device_api", Bool(true))
37-
.set_attr<FTVMRelayToTIR>("RelayToTIR", relay::contrib::example_target_hooks::RelayToTIR())
38-
.set_attr<FTVMTIRToRuntime>("TIRToRuntime", relay::contrib::example_target_hooks::TIRToRuntime);
37+
.set_attr<FTVMRelayToTIR>(attr::kRelayToTIR, relay::contrib::example_target_hooks::RelayToTIR())
38+
.set_attr<FTVMTIRToRuntime>("TIRToRuntime", relay::contrib::example_target_hooks::TIRToRuntime)
39+
.add_attr_option<Integer>("example_attribute", Integer(0));
3940

4041
} // namespace tvm

src/relay/backend/graph_executor_codegen.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
232232
// lowering process directly.
233233
tec::UpdateFunctionMetadata(func, this->function_metadata_);
234234
},
235-
config_->host_virtual_device)(mod);
235+
config_)(mod);
236236

237237
Optional<backend::FunctionInfo> main_func_info =
238238
lowered_mod->GetAttr<backend::FunctionInfo>("main_func_info");

src/relay/backend/interpreter.cc

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -946,7 +946,6 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
946946
* functions needed by the rewritten module.
947947
*/
948948
IRModule Prepare(IRModule mod, CompilationConfig config) {
949-
VirtualDevice host_virtual_device = config->host_virtual_device;
950949
// Run minimal transforms on module to establish invariants needed by interpreter.
951950
transform::Sequential seq(
952951
{transform::SimplifyInference(), qnn::transform::Legalize(),
@@ -962,8 +961,7 @@ IRModule Prepare(IRModule mod, CompilationConfig config) {
962961
transform::EtaExpand(
963962
/*expand_constructor=*/true, /*expand_global_var=*/false),
964963
transform::InferType(),
965-
tec::LowerTEPass(/*module_name=*/"intrp", [](BaseFunc func) { /* no-op */ },
966-
std::move(host_virtual_device))});
964+
tec::LowerTEPass(/*module_name=*/"intrp", [](BaseFunc func) { /* no-op */ }, config)});
967965

968966
transform::PassContext pass_ctx = transform::PassContext::Current();
969967
With<transform::PassContext> ctx(pass_ctx);

0 commit comments

Comments
 (0)