Skip to content

Commit ec24ae6

Browse files
authored
[BYOC] RelayToTIR custom codegen passes can still depend on dynamic shape functions (#11619)
In #11474 I got ready to switch CUTLASS from function-at-a-time to IRModule-at-a-time compilation. However, my approach didn't handle dynamic shape functions, so I adjust it here. The idea is still that such passes will leave behind calls to 'extern' functions. However, converting those calls to 'call_lowered' form in MarkCompilerFunctionsAsExtern is too soon, since only the TECompiler knows how to capture all the attributes necessary to support dynamic shape functions. So stop doing that in MarkCompilerFunctionsAsExtern and instead support this case properly in the TECompiler. While there, try to chip away at the chronic lack of structure in te_compiler.cc. Every little bit helps. Add a basic unit test.
1 parent 53d163c commit ec24ae6

File tree

10 files changed

+503
-228
lines changed

10 files changed

+503
-228
lines changed

src/relay/backend/aot_executor_codegen.cc

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1064,9 +1064,8 @@ class AOTExecutorCodegen : public MixedModeVisitor {
10641064

10651065
mod = transform::ToANormalForm()(mod);
10661066

1067-
IRModule lowered_mod = tec::LowerTEPass(
1068-
mod_name,
1069-
[this, workspace_byte_alignment](BaseFunc func) {
1067+
IRModule lowered_mod =
1068+
tec::LowerTE(mod_name, config_, [this, workspace_byte_alignment](BaseFunc func) {
10701069
// We need to maintain the constant map for external
10711070
// functions so we pass this processing function which
10721071
// allows us to process each function as we lower it.
@@ -1078,8 +1077,7 @@ class AOTExecutorCodegen : public MixedModeVisitor {
10781077
// execute as a further pass, instead writing data to the
10791078
// lowering process directly.
10801079
tec::UpdateFunctionMetadata(func, this->function_metadata_, workspace_byte_alignment);
1081-
},
1082-
config_)(mod);
1080+
})(mod);
10831081

10841082
auto lowered_main = lowered_mod->Lookup("main");
10851083
auto lowered_main_func = GetRef<Function>(lowered_main.as<FunctionNode>());

src/relay/backend/graph_executor_codegen.cc

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -217,22 +217,19 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
217217
mod = WithAttr(mod, "main_func_info", func_info);
218218
}
219219

220-
IRModule lowered_mod = tec::LowerTEPass(
221-
mod_name_,
222-
[this](BaseFunc func) {
223-
// We need to maintain the constant map for external
224-
// functions so we pass this processing function which
225-
// allows us to process each function as we lower it.
226-
if (func->GetAttr<String>(attr::kCompiler).defined()) {
227-
UpdateConstants(func, &params_);
228-
}
220+
IRModule lowered_mod = tec::LowerTE(mod_name_, config_, [this](BaseFunc func) {
221+
// We need to maintain the constant map for external
222+
// functions so we pass this processing function which
223+
// allows us to process each function as we lower it.
224+
if (func->GetAttr<String>(attr::kCompiler).defined()) {
225+
UpdateConstants(func, &params_);
226+
}
229227

230-
// TODO(@areusch, @jroesch): We should refactor this to
231-
// execute as a further pass, instead writing data to the
232-
// lowering process directly.
233-
tec::UpdateFunctionMetadata(func, this->function_metadata_);
234-
},
235-
config_)(mod);
228+
// TODO(@areusch, @jroesch): We should refactor this to
229+
// execute as a further pass, instead writing data to the
230+
// lowering process directly.
231+
tec::UpdateFunctionMetadata(func, this->function_metadata_);
232+
})(mod);
236233

237234
Optional<backend::FunctionInfo> main_func_info =
238235
lowered_mod->GetAttr<backend::FunctionInfo>("main_func_info");

src/relay/backend/interpreter.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -960,8 +960,7 @@ IRModule Prepare(IRModule mod, const CompilationConfig& config) {
960960
// eta expand to support constructors in argument position.
961961
transform::EtaExpand(
962962
/*expand_constructor=*/true, /*expand_global_var=*/false),
963-
transform::InferType(),
964-
tec::LowerTEPass(/*module_name=*/"intrp", [](BaseFunc func) { /* no-op */ }, config)});
963+
transform::InferType(), tec::LowerTE(/*module_name=*/"intrp", config)});
965964

966965
transform::PassContext pass_ctx = transform::PassContext::Current();
967966
With<transform::PassContext> ctx(pass_ctx);

src/relay/backend/te_compiler.cc

Lines changed: 221 additions & 108 deletions
Large diffs are not rendered by default.

src/relay/backend/te_compiler.h

Lines changed: 9 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818
*/
1919

2020
/*!
21-
* \file relay/backend/tir_compiler.h
22-
* * \brief Internal compilation layer which lowers Relay "primitive functions" to TIR PrimFns.
21+
* \file relay/backend/te_compiler.h
22+
* \brief Internal compilation layer which lowers Relay "primitive functions" to TIR PrimFns.
2323
*
2424
*
2525
* This represents the new design of the Relay compilation flow and will replace the interface
@@ -173,36 +173,22 @@ backend::FunctionInfo UpdateMainWorkspaceSize(const IRModule& mod, const Compila
173173
*/
174174
Map<Target, IRModule> GetPerTargetModules(IRModule mod);
175175

176-
/*! \brief Lower an IRModule's primitive functions to TIR.
177-
*
178-
* This is the "back half" of the Relay compiler which lowers "primitive functions"
179-
* to TE expressions, schedules them, and then to TIR.
180-
*
181-
* \param module The IRModule.
182-
* \param memory_plan The memory plan used during lowering
183-
* \param module_name The name of this module
184-
* \param process_fn Callback allowing one-level up code generators to process
185-
* each function that we lower
186-
* \return The lowered module, see above.
187-
*/
188-
IRModule LowerTE(
189-
const IRModule& module, backend::StaticMemoryPlan memory_plan, const String& module_name,
190-
ProcessFn process_fn = [](BaseFunc f) {});
176+
inline void DefaultProcessFn(BaseFunc) {}
191177

192178
/*!
193179
* \brief Pass to lower an IRModule's primitive functions to TIR.
194180
*
195181
* This is the "back half" of the Relay compiler which lowers "primitive functions"
196-
* to TE expressions, schedules them, and then to TIR. It annotates all functions
197-
* with their target.
182+
* to TE expressions, schedules them, and emits PrimFuncs.
198183
*
199-
* \param module_name The name of this module
200-
* \param process_fn Callback allowing one-level up code generators to process
201-
* each function that we lower
184+
* \param module_name The name of this module, used as a prefix for generated globals.
202185
* \param config All available targets.
186+
* \param process_fn Callback allowing one-level up code generators to process
187+
* each function that we lower (default is no-op).
203188
* \returns The pass which lowers primitive functions to TIR
204189
*/
205-
transform::Pass LowerTEPass(String module_name, ProcessFn process_fn, CompilationConfig config);
190+
transform::Pass LowerTE(String module_name, CompilationConfig config,
191+
ProcessFn process_fn = DefaultProcessFn);
206192

207193
} // namespace tec
208194
} // namespace relay

src/relay/backend/vm/compiler.cc

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1040,13 +1040,11 @@ transform::Sequential VMCompiler::FuseAndLowerOperators(const CompilationConfig&
10401040
// Give each "primitive" Function a hash.
10411041
pass_seqs.push_back(LabelOps());
10421042
// Lower "primitive" Functions to PrimFuncs and rewrite calls.
1043-
pass_seqs.push_back(tec::LowerTEPass(/*module_name=*/"vm_mod",
1044-
[this](const BaseFunc& func) {
1045-
if (func->GetAttr<String>(attr::kCompiler).defined()) {
1046-
backend::UpdateConstants(func, &params_);
1047-
}
1048-
},
1049-
config));
1043+
pass_seqs.push_back(tec::LowerTE(/*module_name=*/"vm_mod", config, [this](const BaseFunc& func) {
1044+
if (func->GetAttr<String>(attr::kCompiler).defined()) {
1045+
backend::UpdateConstants(func, &params_);
1046+
}
1047+
}));
10501048
// Since lowered functions are bound in the IRModule, we can now eliminate any unused
10511049
// let-bound functions.
10521050
pass_seqs.push_back(DeadCodeElimination(/*inline_once=*/false));
@@ -1091,13 +1089,11 @@ IRModule VMCompiler::OptimizeModuleImpl(IRModule mod) {
10911089
pass_seqs.push_back(transform::LabelOps());
10921090

10931091
// Lower all functions annotated as "primitive" by FuseOps.
1094-
pass_seqs.push_back(tec::LowerTEPass(/*module_name=*/"vm_mod",
1095-
[this](const BaseFunc& func) {
1096-
if (func->GetAttr<String>(attr::kCompiler).defined()) {
1097-
backend::UpdateConstants(func, &params_);
1098-
}
1099-
},
1100-
config_));
1092+
pass_seqs.push_back(tec::LowerTE(/*module_name=*/"vm_mod", config_, [this](const BaseFunc& func) {
1093+
if (func->GetAttr<String>(attr::kCompiler).defined()) {
1094+
backend::UpdateConstants(func, &params_);
1095+
}
1096+
}));
11011097

11021098
// Since lowered functions are bound in the IRModule, we can now eliminate any unused
11031099
// let-bound functions.

src/relay/transforms/compiler_function_utils.cc

Lines changed: 0 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -81,42 +81,6 @@ class Outliner : public MixedModeMutator {
8181
IRModule mod_;
8282
};
8383

84-
/*!
85-
* \brief Rewrite calls to global "Compiler" functions to use the 'call_lowered' convention.
86-
*/
87-
class CallRewriter : public MixedModeMutator {
88-
public:
89-
CallRewriter(std::string compiler_filter, IRModule mod)
90-
: compiler_filter_(std::move(compiler_filter)), mod_(std::move(mod)) {}
91-
92-
Expr Rewrite_(const CallNode* pre, const Expr& post) final {
93-
Call new_call = Downcast<Call>(post);
94-
if (const auto* global_var_node = new_call->op.as<GlobalVarNode>()) {
95-
if (const auto* function_node =
96-
mod_->Lookup(GetRef<GlobalVar>(global_var_node)).as<FunctionNode>()) {
97-
Optional<String> opt_compiler = function_node->GetAttr<String>(attr::kCompiler);
98-
if (opt_compiler.defined() &&
99-
(compiler_filter_.empty() || opt_compiler.value() == compiler_filter_)) {
100-
Optional<String> opt_global_symbol =
101-
function_node->GetAttr<String>(tvm::attr::kGlobalSymbol);
102-
ICHECK(opt_global_symbol.defined());
103-
GlobalVar global_symbol = mod_->GetGlobalVar(opt_global_symbol.value());
104-
CallLoweredAttrs attrs;
105-
attrs.metadata.Set("relay_attrs", new_call->attrs);
106-
return CallLowered(global_symbol, new_call->args, attrs, new_call->span);
107-
}
108-
}
109-
}
110-
return post;
111-
}
112-
113-
private:
114-
/*! \brief If non-empty, the "Compiler" attribute value to require on functions to outline. */
115-
std::string compiler_filter_;
116-
/*! \brief Module being rewritten. */
117-
IRModule mod_;
118-
};
119-
12084
} // namespace
12185

12286
GlobalSymbolCache::~GlobalSymbolCache() = default;
@@ -169,20 +133,6 @@ transform::Pass MarkCompilerFunctionsAsExtern(std::string compiler_filter) {
169133
runtime::TypedPackedFunc<IRModule(IRModule, transform::PassContext)> pass_func =
170134
[compiler_filter = std::move(compiler_filter)](IRModule mod, transform::PassContext ctx) {
171135
IRModule output_mod = mod->ShallowCopy();
172-
173-
// First pass, rewrite the calls.
174-
// We have to do this before marking functions as 'extern' to know which calls to rewrite!
175-
for (const auto& kv : mod->functions) {
176-
if (const auto* function_node = AsOptimizableFunctionNode(kv.second)) {
177-
Expr new_body =
178-
CallRewriter(compiler_filter, output_mod).VisitExpr(function_node->body);
179-
Function new_function =
180-
WithFields(GetRef<Function>(function_node), /*opt_params=*/{}, new_body);
181-
output_mod->Update(kv.first, new_function);
182-
}
183-
}
184-
185-
// Second pass, mark functions as 'extern'.
186136
for (const auto& kv : mod->functions) {
187137
if (const auto* function_node = kv.second.as<FunctionNode>()) {
188138
Optional<String> opt_compiler = function_node->GetAttr<String>(attr::kCompiler);
@@ -197,7 +147,6 @@ transform::Pass MarkCompilerFunctionsAsExtern(std::string compiler_filter) {
197147
}
198148
}
199149
}
200-
201150
return output_mod;
202151
};
203152

src/relay/transforms/compiler_function_utils.h

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,8 @@
4343
*
4444
* - \p MarkCompilerFunctionsAsExtern will replace global functions with a matching "Compiler"
4545
* attribute with the same function with just an "Extern" attribute, signalling the function
46-
* has been dealt with. Calls to such functions will be rewritten to use the 'call_lowered'
47-
* calling convention. Can be used after lowering to cleanup the IRModule.
48-
*
49-
* Note that the above behaviour is hard coded within the TECompiler, but is only available to
50-
* external codegen using the Function-at-a-time "relay.ext.toolchain" extension point.
46+
* has been dealt with. However calls to such functions will be left unchanged. Can be used
47+
* after lowering to cleanup the IRModule.
5148
*/
5249

5350
#ifndef TVM_RELAY_TRANSFORMS_COMPILER_FUNCTION_UTILS_H_
@@ -118,8 +115,8 @@ transform::Pass OutlineCompilerFunctionsWithExistingGlobalSymbols(std::string co
118115

119116
/*!
120117
* \brief A pass to mark all global functions which have a "Compiler" attribute matching
121-
* compiler_filter as 'extern' by replacing all attributes with a single "Extern" attribute, and
122-
* rewrite all calls to such functions to use the 'call_lowered' calling convention.
118+
* compiler_filter as 'extern' by replacing all attributes with a single "Extern" attribute.
119+
* Calls to such functions are not changed.
123120
*
124121
* If \p compiler_filter is non-empty only functions with that as their attribute value are
125122
* outlined.

0 commit comments

Comments
 (0)