Skip to content

Commit db09ba6

Browse files
committed
- add unit test
1 parent 635f45a commit db09ba6

File tree

8 files changed

+287
-59
lines changed

8 files changed

+287
-59
lines changed

src/printer/relay_text_printer.cc

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,12 +64,16 @@ Doc RelayTextPrinter::PrintOptionalInfo(const Expr& expr) {
6464
// default annotations
6565
if (annotate_ == nullptr) {
6666
if ((expr.as<ConstantNode>() || expr.as<CallNode>() || expr.as<VarNode>() ||
67-
expr.as<FunctionNode>() || expr.as<TupleNode>() || expr.as<TupleGetItemNode>()) &&
67+
expr.as<GlobalVarNode>() || expr.as<FunctionNode>() || expr.as<TupleNode>() ||
68+
expr.as<TupleGetItemNode>()) &&
6869
(expr->checked_type_.defined() || expr->span.defined())) {
6970
doc << " /*";
7071
if (expr->checked_type_.defined()) {
7172
doc << " ty=" << Print(expr->checked_type());
7273
}
74+
if (!expr->virtual_device()->IsFullyUnconstrained()) {
75+
doc << " virtual_device=" << Print(expr->virtual_device());
76+
}
7377
if (expr->span.defined()) {
7478
doc << " span=" << PrintSpan(expr->span);
7579
}

src/relay/backend/aot_executor_codegen.cc

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1064,9 +1064,8 @@ class AOTExecutorCodegen : public MixedModeVisitor {
10641064

10651065
mod = transform::ToANormalForm()(mod);
10661066

1067-
IRModule lowered_mod = tec::LowerTEPass(
1068-
mod_name,
1069-
[this, workspace_byte_alignment](BaseFunc func) {
1067+
IRModule lowered_mod =
1068+
tec::LowerTE(mod_name, config_, [this, workspace_byte_alignment](BaseFunc func) {
10701069
// We need to maintain the constant map for external
10711070
// functions so we pass this processing function which
10721071
// allows us to process each function as we lower it.
@@ -1078,8 +1077,7 @@ class AOTExecutorCodegen : public MixedModeVisitor {
10781077
// execute as a further pass, instead writing data to the
10791078
// lowering process directly.
10801079
tec::UpdateFunctionMetadata(func, this->function_metadata_, workspace_byte_alignment);
1081-
},
1082-
config_)(mod);
1080+
})(mod);
10831081

10841082
auto lowered_main = lowered_mod->Lookup("main");
10851083
auto lowered_main_func = GetRef<Function>(lowered_main.as<FunctionNode>());

src/relay/backend/graph_executor_codegen.cc

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -217,22 +217,19 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
217217
mod = WithAttr(mod, "main_func_info", func_info);
218218
}
219219

220-
IRModule lowered_mod = tec::LowerTEPass(
221-
mod_name_,
222-
[this](BaseFunc func) {
223-
// We need to maintain the constant map for external
224-
// functions so we pass this processing function which
225-
// allows us to process each function as we lower it.
226-
if (func->GetAttr<String>(attr::kCompiler).defined()) {
227-
UpdateConstants(func, &params_);
228-
}
220+
IRModule lowered_mod = tec::LowerTE(mod_name_, config_, [this](BaseFunc func) {
221+
// We need to maintain the constant map for external
222+
// functions so we pass this processing function which
223+
// allows us to process each function as we lower it.
224+
if (func->GetAttr<String>(attr::kCompiler).defined()) {
225+
UpdateConstants(func, &params_);
226+
}
229227

230-
// TODO(@areusch, @jroesch): We should refactor this to
231-
// execute as a further pass, instead writing data to the
232-
// lowering process directly.
233-
tec::UpdateFunctionMetadata(func, this->function_metadata_);
234-
},
235-
config_)(mod);
228+
// TODO(@areusch, @jroesch): We should refactor this to
229+
// execute as a further pass, instead writing data to the
230+
// lowering process directly.
231+
tec::UpdateFunctionMetadata(func, this->function_metadata_);
232+
})(mod);
236233

237234
Optional<backend::FunctionInfo> main_func_info =
238235
lowered_mod->GetAttr<backend::FunctionInfo>("main_func_info");

src/relay/backend/interpreter.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -960,8 +960,7 @@ IRModule Prepare(IRModule mod, const CompilationConfig& config) {
960960
// eta expand to support constructors in argument position.
961961
transform::EtaExpand(
962962
/*expand_constructor=*/true, /*expand_global_var=*/false),
963-
transform::InferType(),
964-
tec::LowerTEPass(/*module_name=*/"intrp", [](BaseFunc func) { /* no-op */ }, config)});
963+
transform::InferType(), tec::LowerTE(/*module_name=*/"intrp", config)});
965964

966965
transform::PassContext pass_ctx = transform::PassContext::Current();
967966
With<transform::PassContext> ctx(pass_ctx);

src/relay/backend/te_compiler.cc

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1152,6 +1152,7 @@ void UpdateFunctionMetadata(BaseFunc func,
11521152
function_metadata.Set(prim_fn_var.value()->name_hint, fi);
11531153
}
11541154

1155+
/*! \brief Main lowering driving. */
11551156
IRModule LowerTE(const IRModule& module, const String& module_name, ProcessFn process_fn,
11561157
CompilationConfig config) {
11571158
TECompiler compiler(module);
@@ -1269,7 +1270,7 @@ Map<Target, IRModule> GetPerTargetModules(IRModule mod) {
12691270
return per_target_modules;
12701271
}
12711272

1272-
Pass LowerTEPass(String module_name, ProcessFn process_fn, CompilationConfig complilation_config) {
1273+
Pass LowerTE(String module_name, CompilationConfig complilation_config, ProcessFn process_fn) {
12731274
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = [=](IRModule module,
12741275
PassContext ctx) {
12751276
return LowerTE(module, module_name, process_fn, complilation_config);
@@ -1280,6 +1281,12 @@ Pass LowerTEPass(String module_name, ProcessFn process_fn, CompilationConfig com
12801281
tvm::transform::CreateModulePass(pass_func, 0, "LowerTE", {"InferType"}), InferType(),
12811282
tvm::tir::transform::ExtractPrimFuncConstants()});
12821283
}
1284+
1285+
TVM_REGISTER_GLOBAL("relay.tec.LowerTE")
1286+
.set_body_typed([](String module_name, CompilationConfig compilation_config) {
1287+
return LowerTE(std::move(module_name), std::move(compilation_config));
1288+
});
1289+
12831290
} // namespace tec
12841291
} // namespace relay
12851292
} // namespace tvm

src/relay/backend/te_compiler.h

Lines changed: 7 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -173,36 +173,22 @@ backend::FunctionInfo UpdateMainWorkspaceSize(const IRModule& mod, const Compila
173173
*/
174174
Map<Target, IRModule> GetPerTargetModules(IRModule mod);
175175

176-
/*! \brief Lower an IRModule's primitive functions to TIR.
177-
*
178-
* This is the "back half" of the Relay compiler which lowers "primitive functions"
179-
* to TE expressions, schedules them, and then to TIR.
180-
*
181-
* \param module The IRModule.
182-
* \param memory_plan The memory plan used during lowering
183-
* \param module_name The name of this module
184-
* \param process_fn Callback allowing one-level up code generators to process
185-
* each function that we lower
186-
* \return The lowered module, see above.
187-
*/
188-
IRModule LowerTE(
189-
const IRModule& module, backend::StaticMemoryPlan memory_plan, const String& module_name,
190-
ProcessFn process_fn = [](BaseFunc f) {});
176+
inline void DefaultProcessFn(BaseFunc) {}
191177

192178
/*!
193179
* \brief Pass to lower an IRModule's primitive functions to TIR.
194180
*
195181
* This is the "back half" of the Relay compiler which lowers "primitive functions"
196-
* to TE expressions, schedules them, and then to TIR. It annotates all functions
197-
* with their target.
182+
* to TE expressions, schedules them, and emits PrimFuncs.
198183
*
199-
* \param module_name The name of this module
200-
* \param process_fn Callback allowing one-level up code generators to process
201-
* each function that we lower
184+
* \param module_name The name of this module, used as a prefix for generated globals.
202185
* \param config All available targets.
186+
* \param process_fn Callback allowing one-level up code generators to process
187+
* each function that we lower (default is no-op).
203188
* \returns The pass which lowers primitive functions to TIR
204189
*/
205-
transform::Pass LowerTEPass(String module_name, ProcessFn process_fn, CompilationConfig config);
190+
transform::Pass LowerTE(String module_name, CompilationConfig config,
191+
ProcessFn process_fn = DefaultProcessFn);
206192

207193
} // namespace tec
208194
} // namespace relay

src/relay/backend/vm/compiler.cc

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1039,13 +1039,11 @@ transform::Sequential VMCompiler::FuseAndLowerOperators(const CompilationConfig&
10391039
// Give each "primitive" Function a hash.
10401040
pass_seqs.push_back(LabelOps());
10411041
// Lower "primitive" Functions to PrimFuncs and rewrite calls.
1042-
pass_seqs.push_back(tec::LowerTEPass(/*module_name=*/"vm_mod",
1043-
[this](const BaseFunc& func) {
1044-
if (func->GetAttr<String>(attr::kCompiler).defined()) {
1045-
backend::UpdateConstants(func, &params_);
1046-
}
1047-
},
1048-
config));
1042+
pass_seqs.push_back(tec::LowerTE(/*module_name=*/"vm_mod", config, [this](const BaseFunc& func) {
1043+
if (func->GetAttr<String>(attr::kCompiler).defined()) {
1044+
backend::UpdateConstants(func, &params_);
1045+
}
1046+
}));
10491047
// Since lowered functions are bound in the IRModule, we can now eliminate any unused
10501048
// let-bound functions.
10511049
pass_seqs.push_back(DeadCodeElimination(/*inline_once=*/false));
@@ -1090,13 +1088,11 @@ IRModule VMCompiler::OptimizeModuleImpl(IRModule mod) {
10901088
pass_seqs.push_back(transform::LabelOps());
10911089

10921090
// Lower all functions annotated as "primitive" by FuseOps.
1093-
pass_seqs.push_back(tec::LowerTEPass(/*module_name=*/"vm_mod",
1094-
[this](const BaseFunc& func) {
1095-
if (func->GetAttr<String>(attr::kCompiler).defined()) {
1096-
backend::UpdateConstants(func, &params_);
1097-
}
1098-
},
1099-
config_));
1091+
pass_seqs.push_back(tec::LowerTE(/*module_name=*/"vm_mod", config_, [this](const BaseFunc& func) {
1092+
if (func->GetAttr<String>(attr::kCompiler).defined()) {
1093+
backend::UpdateConstants(func, &params_);
1094+
}
1095+
}));
11001096

11011097
// Since lowered functions are bound in the IRModule, we can now eliminate any unused
11021098
// let-bound functions.

0 commit comments

Comments
 (0)