Skip to content

Commit 6adb189

Browse files
committed
Convert AOT to TECompiler
This removes the dependency on "compile_engine.h" from aot_executor_codegen.cc. This required a few changes to how AOT operates:
* AOT run_model is now based on the post-lowering main_module
* AOTOnDemandAllocator is run twice to ensure SIDs are updated post-lowering
* Moved to using tec::UpdateFunctionMetadata

Tests are passing, but I would appreciate additional validation 😸
1 parent ade2d4d commit 6adb189

File tree

1 file changed

+64
-139
lines changed

1 file changed

+64
-139
lines changed

src/relay/backend/aot_executor_codegen.cc

Lines changed: 64 additions & 139 deletions
Original file line numberDiff line numberDiff line change
@@ -38,15 +38,14 @@
3838
#include <string>
3939
#include <vector>
4040

41-
#include "compile_engine.h"
41+
#include "te_compiler.h"
4242
#include "utils.h"
4343

4444
namespace tvm {
4545
namespace relay {
4646
namespace backend {
4747

4848
using IntegerArray = Array<Integer>;
49-
using TargetsMap = std::unordered_map<int, Target>;
5049
using StorageMap =
5150
std::unordered_map<Expr, StorageInfo, runtime::ObjectPtrHash, runtime::ObjectPtrEqual>;
5251

@@ -287,7 +286,6 @@ class AOTExecutorCodegen : public ExprVisitor {
287286
void CreateFuncCall(Call call, std::string func_name) {
288287
tvm::Array<PrimExpr> args{tvm::tir::StringImm(func_name)};
289288
std::vector<tir::Stmt> create_func_call_stmts;
290-
291289
// Pack the inputs
292290
for (Expr arg : call->args) {
293291
if (params_by_expr_.find(arg) != params_by_expr_.end()) {
@@ -399,121 +397,21 @@ class AOTExecutorCodegen : public ExprVisitor {
399397
function_metadata_.Set(String(runtime::symbol::tvm_module_main), FunctionInfo(fi_node));
400398
}
401399

402-
/*!
403-
* \brief Update the function metadata for a given cached function and its relay
404-
* primitive function.
405-
*
406-
* \param cfunc The cached function as provided the by the compile engine
407-
* \param relay_func The source relay primitive function
408-
* \param relay_target The target associated with relay primitive function
409-
*/
410-
void UpdateFunctionMetadata(const CachedFunc& cfunc, const Function& relay_func,
411-
const Target& relay_target) {
412-
auto fi_node = make_object<FunctionInfoNode>();
413-
for (const auto& kv : cfunc->funcs->functions) {
414-
auto primfunc = Downcast<tir::PrimFunc>(kv.second);
415-
auto workspace_byte_alignment =
416-
target_host_->GetAttr<Integer>("workspace-byte-alignment").value_or(16);
417-
Integer workspace_size = CalculateWorkspaceBytes(primfunc, workspace_byte_alignment);
418-
Target primfunc_target = relay_target;
419-
if (primfunc->attrs->dict.count("target")) {
420-
primfunc_target = Downcast<Target>(primfunc->attrs->dict["target"]);
421-
}
422-
fi_node->workspace_sizes.Set(primfunc_target, workspace_size);
423-
// Calculating size for I/O
424-
for (auto const& param : primfunc->params) {
425-
auto p_shape = primfunc->buffer_map[param]->shape;
426-
int num_of_elements = 1;
427-
for (const auto& dim_index_expr : p_shape) {
428-
if (dim_index_expr->IsInstance<IntImmNode>()) {
429-
num_of_elements *= dim_index_expr.as<IntImmNode>()->value;
430-
} else {
431-
// If shape is dynamic, we cannot calculate workspace in compile time.
432-
num_of_elements = 0;
433-
}
434-
}
435-
int element_size = primfunc->buffer_map[param]->dtype.bytes();
436-
fi_node->io_sizes.Set(primfunc_target, element_size * num_of_elements);
437-
}
438-
fi_node->constant_sizes.Set(primfunc_target, 0);
439-
fi_node->tir_primfuncs.Set(primfunc_target, primfunc);
440-
fi_node->relay_primfuncs.Set(primfunc_target, relay_func);
441-
}
442-
function_metadata_.Set(cfunc->prim_fn_var->name_hint, FunctionInfo(fi_node));
443-
}
444-
445400
void VisitExpr_(const CallNode* op) override {
446401
// Descend the call tree
447402
for (auto arg : op->args) {
448403
VisitExpr(arg);
449404
}
450405

451-
Expr expr = GetRef<Expr>(op);
452-
Function func;
453406
if (op->op.as<OpNode>()) {
454407
LOG(FATAL) << "Operators should be transformed away; try applying"
455408
<< "the fuse_ops transformation to the expression.";
456409
} else if (op->op.as<GlobalVarNode>()) {
457-
LOG(FATAL) << "Not implemented";
458-
} else if (op->op.as<FunctionNode>()) {
459-
func = GetRef<Function>(op->op.as<FunctionNode>());
410+
GlobalVar node = GetRef<GlobalVar>(op->op.as<GlobalVarNode>());
411+
CreateFuncCall(GetRef<Call>(op), node->name_hint);
460412
} else {
461413
LOG(FATAL) << "TVM runtime does not support calls to " << op->op->GetTypeKey();
462414
}
463-
if (!func->HasNonzeroAttr(attr::kPrimitive)) {
464-
LOG(FATAL) << "TVM only support calls to primitive functions "
465-
<< "(i.e functions composed of fusable operator invocations)";
466-
}
467-
468-
Target target;
469-
470-
// Handle external function
471-
if (func->GetAttr<String>(attr::kCompiler).defined()) {
472-
target = Target("ext_dev");
473-
CCacheKey key = CCacheKey(func, target);
474-
CachedFunc ext_func = compile_engine_->Lower(key, mod_name_);
475-
ICHECK(ext_func.defined()) << "External function is not defined.";
476-
UpdateConstants(func, &params_);
477-
478-
// Generate the TIR function call
479-
CreateFuncCall(GetRef<Call>(op), ext_func->prim_fn_var->name_hint);
480-
return;
481-
}
482-
483-
ICHECK_GE(storage_device_map_.count(expr), 0);
484-
StorageInfo& sinfo = storage_device_map_[expr];
485-
auto call_dev_type = sinfo->device_types[0];
486-
// Normal Relay Function
487-
if (targets_.size() == 1) {
488-
// homogeneous execution.
489-
const auto& it = targets_.begin();
490-
target = (*it).second;
491-
} else {
492-
// heterogeneous execution.
493-
std::string call_dev_name;
494-
if (call_dev_type == 0) {
495-
call_dev_name = "llvm";
496-
} else {
497-
call_dev_name = runtime::DeviceName(call_dev_type);
498-
}
499-
if (targets_.count(call_dev_type) == 0) {
500-
LOG(FATAL) << "No target is provided for device " << call_dev_name;
501-
}
502-
target = targets_[call_dev_type];
503-
}
504-
505-
CCacheKey key = CCacheKey(func, target);
506-
CachedFunc lowered_func = compile_engine_->Lower(key, mod_name_);
507-
508-
if (!lowered_funcs_.count(target->str())) {
509-
lowered_funcs_[target->str()] = IRModule(Map<GlobalVar, BaseFunc>({}));
510-
}
511-
lowered_funcs_[target->str()]->Update(lowered_func->funcs);
512-
// Update function metadata via looking at all primfuncs
513-
UpdateFunctionMetadata(lowered_func, func, target);
514-
515-
// Generate the TIR function call
516-
CreateFuncCall(GetRef<Call>(op), lowered_func->prim_fn_var->name_hint);
517415
}
518416

519417
void VisitExpr_(const VarNode* op) override {
@@ -598,7 +496,7 @@ class AOTExecutorCodegen : public ExprVisitor {
598496
// Create the main PrimFunc to execute the graph. Please note that
599497
// the packed function calls don't pack their arguments. The AOT
600498
// runner function needs to be legalized by the LegalizePackedCalls pass.
601-
tir::PrimFunc CreateMainFunc(unsigned int relay_params) {
499+
tir::PrimFunc CreateMainFunc(String mod_name, unsigned int relay_params) {
602500
tir::Stmt body = tir::SeqStmt(stmts_);
603501

604502
// Allocate the sids
@@ -637,7 +535,7 @@ class AOTExecutorCodegen : public ExprVisitor {
637535
// Define the PrimFunc attributes
638536
Map<String, ObjectRef> dict_attrs;
639537
String run_func_name =
640-
runtime::get_name_mangled(mod_name_, runtime::symbol::tvm_run_func_suffix);
538+
runtime::get_name_mangled(mod_name, runtime::symbol::tvm_run_func_suffix);
641539
dict_attrs.Set("global_symbol", run_func_name);
642540
dict_attrs.Set("runner_function", Bool(true));
643541

@@ -654,7 +552,7 @@ class AOTExecutorCodegen : public ExprVisitor {
654552
/*! \brief input and output variables belonging to the main function signature */
655553
Array<tir::Var> main_signature_;
656554
/*! \brief target device */
657-
TargetsMap targets_;
555+
tec::TargetMap targets_;
658556
/*! \brief target host */
659557
Target target_host_;
660558
/*!
@@ -684,35 +582,68 @@ class AOTExecutorCodegen : public ExprVisitor {
684582
/*! \brief mapping sid -> tir::Var */
685583
std::unordered_map<int, te::Var> sids_table_;
686584
/*! \brief lowered funcs */
687-
std::unordered_map<std::string, IRModule> lowered_funcs_;
688-
/*! \brief lowered funcs */
689585
Map<String, FunctionInfo> function_metadata_;
690-
/*! \brief compile engine */
691-
CompileEngine compile_engine_;
692586
/*! \brief the set of statements that make the program */
693587
std::vector<tir::Stmt> stmts_;
694588
/*! \brief the list of return sids (note that the function might return more than one output) */
695589
std::vector<int> return_sid_;
696-
/*! \brief the module name we use to mangle the function names */
697-
String mod_name_;
698590

699591
public:
700-
AOTExecutorCodegen(runtime::Module* mod, const TargetsMap& targets, Target target_host)
592+
AOTExecutorCodegen(runtime::Module* mod, const tec::TargetMap& targets, Target target_host)
701593
: mod_(mod),
702594
targets_(targets),
703595
target_host_(target_host),
704-
use_unpacked_api_(target_host->GetAttr<Bool>("unpacked-api").value_or(Bool(false))),
705-
compile_engine_(CompileEngine::Global()) {}
596+
use_unpacked_api_(target_host->GetAttr<Bool>("unpacked-api").value_or(Bool(false))) {}
706597

707598
LoweredOutput Codegen(relay::Function func, String mod_name) {
708599
auto aot_allocator = AOTOnDemandAllocator();
709600
aot_allocator.Run(func);
710601

711-
// Retrieve the storage map
712-
storage_device_map_ = aot_allocator.GetStorageMap();
713-
mod_name_ = mod_name;
602+
// Pre-lowering storage map and memory plan
603+
StorageMap initial_storage_map = aot_allocator.GetStorageMap();
604+
StaticMemoryPlan memory_plan(initial_storage_map);
605+
606+
// Build a map from each operation to device.
607+
tec::DeviceMap device_context_map;
608+
for (const auto& it : memory_plan->expr_to_storage_info) {
609+
auto expr = it.first;
610+
auto storage_info = it.second;
611+
auto device_types = storage_info->device_types;
612+
// CHECK_EQ(device_types.size(), 1);
613+
tvm::Device dev;
614+
dev.device_id = 0;
615+
dev.device_type = device_types[0];
616+
device_context_map.insert({expr, dev});
617+
}
618+
619+
// This first phase moves from implicit use of compile engine,
620+
// to instead explicitly lowering the incoming IRModule, and then
621+
// performing the preexisting AOT executor code generation phase.
622+
IRModule mod = IRModule::FromExpr(func);
623+
auto lowered_module = tec::LowerTE(
624+
mod, targets_, device_context_map, memory_plan, mod_name, [this](Function func) {
625+
// We need to maintain the constant map for external
626+
// functions so we pass this processing function which
627+
// allows us to process each function as we lower it.
628+
if (func->GetAttr<String>(attr::kCompiler).defined()) {
629+
UpdateConstants(func, &params_);
630+
}
631+
632+
// TODO(@areusch, @jroesch): We should refactor this to
633+
// execute as a further pass, instead writing data to the
634+
// lowering process directly.
635+
tec::UpdateFunctionMetadata(func, this->function_metadata_);
636+
});
714637

715-
for (auto input : func->params) {
638+
auto lowered_main = lowered_module.main_module->Lookup("main");
639+
auto lowered_main_func = GetRef<Function>(lowered_main.as<FunctionNode>());
640+
641+
// Post-lowering storage map for writing main func
642+
auto new_allocator = AOTOnDemandAllocator();
643+
new_allocator.Run(lowered_main_func);
644+
storage_device_map_ = new_allocator.GetStorageMap();
645+
646+
for (auto input : lowered_main_func->params) {
716647
input_vars_.push_back(input);
717648
main_signature_.push_back(tir::Var("input", DataType::Handle()));
718649
}
@@ -732,13 +663,14 @@ class AOTExecutorCodegen : public ExprVisitor {
732663
main_signature_.push_back(tir::Var("output", DataType::Handle()));
733664
}
734665

735-
VisitExpr(func->body);
666+
VisitExpr(lowered_main_func->body);
736667

737668
// Create the runner function. Please note that the function is not legal yet
738669
// because the packed calls arguments are not wrapped in TVMValues. To make this happen we need
739670
// to run the LegalizePackedCalls pass.
740-
auto prim_func = CreateMainFunc(func->params.size());
741-
UpdateMainWorkspaceSize(prim_func, func);
671+
auto prim_func = CreateMainFunc(mod_name, lowered_main_func->params.size());
672+
UpdateMainWorkspaceSize(prim_func, lowered_main_func);
673+
742674
LoweredOutput ret;
743675

744676
ret.params = std::unordered_map<std::string, std::pair<int, const tvm::runtime::NDArray>>();
@@ -748,17 +680,7 @@ class AOTExecutorCodegen : public ExprVisitor {
748680
std::make_pair(static_cast<int>(param_storage_ids_[param.first]), param.second)));
749681
}
750682

751-
for (auto& kv : lowered_funcs_) {
752-
if (ret.lowered_funcs.count(kv.first) == 0) {
753-
ret.lowered_funcs.Set(kv.first, IRModule(Map<GlobalVar, BaseFunc>({})));
754-
}
755-
auto& mod = ret.lowered_funcs[kv.first];
756-
mod->Update(kv.second);
757-
ret.lowered_funcs.Set(kv.first, mod);
758-
}
759-
ret.external_mods = compile_engine_->LowerExternalFunctions();
760-
761-
// Build the TIR IRModule
683+
// Build the TIR IRModule for the AOT function
762684
Map<GlobalVar, BaseFunc> symbol_map;
763685
symbol_map.Set(GlobalVar(::tvm::runtime::symbol::tvm_run_func_suffix), prim_func);
764686
IRModule mod_run(symbol_map);
@@ -774,14 +696,17 @@ class AOTExecutorCodegen : public ExprVisitor {
774696
mod_run = pack_calls(mod_run);
775697
}
776698

777-
// Update the lowered functions
699+
ret.function_metadata = std::move(function_metadata_);
700+
701+
ret.lowered_funcs = lowered_module.per_target_module;
702+
ret.external_mods = lowered_module.external_mods;
703+
778704
auto target_host_str = target_host_->str();
779705
if (ret.lowered_funcs.find(target_host_str) != ret.lowered_funcs.end()) {
780706
ret.lowered_funcs[target_host_str]->Update(mod_run);
781707
} else {
782708
ret.lowered_funcs.Set(target_host_str, mod_run);
783709
}
784-
ret.function_metadata = std::move(function_metadata_);
785710

786711
std::vector<String> input_var_names(input_vars_.size());
787712
std::transform(input_vars_.begin(), input_vars_.end(), input_var_names.begin(),
@@ -845,15 +770,15 @@ class AOTExecutorCodegenModule : public runtime::ModuleNode {
845770

846771
private:
847772
void init(void* mod, Map<Integer, tvm::Target> tmp) {
848-
TargetsMap targets;
773+
tec::TargetMap targets;
849774
Target target_host;
850775
for (const auto& it : tmp) {
851776
auto dev_type = it.first.as<tir::IntImmNode>();
852777
if (!target_host.defined() && it.second->kind->device_type == kDLCPU) {
853778
target_host = it.second;
854779
}
855780
ICHECK(dev_type);
856-
targets[dev_type->value] = it.second;
781+
targets[static_cast<DLDeviceType>(dev_type->value)] = it.second;
857782
}
858783
codegen_ = std::make_shared<AOTExecutorCodegen>(reinterpret_cast<runtime::Module*>(mod),
859784
targets, target_host);

0 commit comments

Comments
 (0)