3838#include < string>
3939#include < vector>
4040
41- #include " compile_engine .h"
41+ #include " te_compiler .h"
4242#include " utils.h"
4343
4444namespace tvm {
4545namespace relay {
4646namespace backend {
4747
4848using IntegerArray = Array<Integer>;
49- using TargetsMap = std::unordered_map<int , Target>;
5049using StorageMap =
5150 std::unordered_map<Expr, StorageInfo, runtime::ObjectPtrHash, runtime::ObjectPtrEqual>;
5251
@@ -287,7 +286,6 @@ class AOTExecutorCodegen : public ExprVisitor {
287286 void CreateFuncCall (Call call, std::string func_name) {
288287 tvm::Array<PrimExpr> args{tvm::tir::StringImm (func_name)};
289288 std::vector<tir::Stmt> create_func_call_stmts;
290-
291289 // Pack the inputs
292290 for (Expr arg : call->args ) {
293291 if (params_by_expr_.find (arg) != params_by_expr_.end ()) {
@@ -399,121 +397,21 @@ class AOTExecutorCodegen : public ExprVisitor {
399397 function_metadata_.Set (String (runtime::symbol::tvm_module_main), FunctionInfo (fi_node));
400398 }
401399
402- /* !
403- * \brief Update the function metadata for a given cached function and its relay
404- * primitive function.
405- *
 * For every PrimFunc in cfunc->funcs->functions this records the workspace size,
 * the I/O size, a constant size of 0, the TIR PrimFunc and the Relay function,
 * each keyed by the PrimFunc's own "target" attribute when present and by
 * relay_target otherwise. The result is stored in function_metadata_ under
 * cfunc->prim_fn_var->name_hint.
 *
406- * \param cfunc The cached function as provided by the compile engine
407- * \param relay_func The source relay primitive function
408- * \param relay_target The target associated with relay primitive function
409- */
410- void UpdateFunctionMetadata (const CachedFunc& cfunc, const Function& relay_func,
411- const Target& relay_target) {
412- auto fi_node = make_object<FunctionInfoNode>();
413- for (const auto & kv : cfunc->funcs ->functions ) {
414- auto primfunc = Downcast<tir::PrimFunc>(kv.second );
415- auto workspace_byte_alignment =
416- target_host_->GetAttr <Integer>(" workspace-byte-alignment" ).value_or (16 );
417- Integer workspace_size = CalculateWorkspaceBytes (primfunc, workspace_byte_alignment);
418- Target primfunc_target = relay_target;
419- if (primfunc->attrs ->dict .count (" target" )) {
420- primfunc_target = Downcast<Target>(primfunc->attrs ->dict [" target" ]);
421- }
422- fi_node->workspace_sizes .Set (primfunc_target, workspace_size);
423- // Calculating size for I/O
// NOTE(review): io_sizes.Set below is called with the same key (primfunc_target) for
// every parameter, so presumably only the last parameter's size is retained rather
// than a sum over all parameters -- confirm this is intended.
424- for (auto const & param : primfunc->params ) {
425- auto p_shape = primfunc->buffer_map [param]->shape ;
426- int num_of_elements = 1 ;
427- for (const auto & dim_index_expr : p_shape) {
428- if (dim_index_expr->IsInstance <IntImmNode>()) {
429- num_of_elements *= dim_index_expr.as <IntImmNode>()->value ;
430- } else {
431- // If shape is dynamic, we cannot calculate the size at compile time; report 0.
432- num_of_elements = 0 ;
433- }
434- }
435- int element_size = primfunc->buffer_map [param]->dtype .bytes ();
436- fi_node->io_sizes .Set (primfunc_target, element_size * num_of_elements);
437- }
438- fi_node->constant_sizes .Set (primfunc_target, 0 );
439- fi_node->tir_primfuncs .Set (primfunc_target, primfunc);
440- fi_node->relay_primfuncs .Set (primfunc_target, relay_func);
441- }
442- function_metadata_.Set (cfunc->prim_fn_var ->name_hint , FunctionInfo (fi_node));
443- }
444-
// Visit a call node. Every argument is visited first; the callee kind then decides
// what happens: an OpNode callee is a fatal error, a GlobalVarNode callee emits a call
// through CreateFuncCall using the global's name_hint (the '+' lines), and any other
// callee kind is a fatal error.
// The '-' lines are the code this change removes: they handled FunctionNode callees by
// lowering them through compile_engine_, recording function metadata, and then emitting
// the call under the lowered function's name.
445400 void VisitExpr_ (const CallNode* op) override {
446401 // Descend the call tree
447402 for (auto arg : op->args ) {
448403 VisitExpr (arg);
449404 }
450405
451- Expr expr = GetRef<Expr>(op);
452- Function func;
453406 if (op->op .as <OpNode>()) {
454407 LOG (FATAL) << " Operators should be transformed away; try applying"
455408 << " the fuse_ops transformation to the expression." ;
456409 } else if (op->op .as <GlobalVarNode>()) {
457- LOG (FATAL) << " Not implemented" ;
458- } else if (op->op .as <FunctionNode>()) {
459- func = GetRef<Function>(op->op .as <FunctionNode>());
410+ GlobalVar node = GetRef<GlobalVar>(op->op .as <GlobalVarNode>());
411+ CreateFuncCall (GetRef<Call>(op), node->name_hint );
460412 } else {
461413 LOG (FATAL) << " TVM runtime does not support calls to " << op->op ->GetTypeKey ();
462414 }
463- if (!func->HasNonzeroAttr (attr::kPrimitive )) {
464- LOG (FATAL) << " TVM only support calls to primitive functions "
465- << " (i.e functions composed of fusable operator invocations)" ;
466- }
467-
468- Target target;
469-
470- // Handle external function
471- if (func->GetAttr <String>(attr::kCompiler ).defined ()) {
472- target = Target (" ext_dev" );
473- CCacheKey key = CCacheKey (func, target);
474- CachedFunc ext_func = compile_engine_->Lower (key, mod_name_);
475- ICHECK (ext_func.defined ()) << " External function is not defined." ;
// NOTE(review): "¶ms_" on the next line looks like HTML-entity damage of "&params_"
// introduced when this diff was scraped -- confirm against the real file.
476- UpdateConstants (func, ¶ms_);
477-
478- // Generate the TIR function call
479- CreateFuncCall (GetRef<Call>(op), ext_func->prim_fn_var ->name_hint );
480- return ;
481- }
482-
// NOTE(review): count() is never negative, so the check on the next line cannot fail.
483- ICHECK_GE (storage_device_map_.count (expr), 0 );
484- StorageInfo& sinfo = storage_device_map_[expr];
485- auto call_dev_type = sinfo->device_types [0 ];
486- // Normal Relay Function
487- if (targets_.size () == 1 ) {
488- // homogeneous execution.
489- const auto & it = targets_.begin ();
490- target = (*it).second ;
491- } else {
492- // heterogeneous execution.
493- std::string call_dev_name;
494- if (call_dev_type == 0 ) {
495- call_dev_name = " llvm" ;
496- } else {
497- call_dev_name = runtime::DeviceName (call_dev_type);
498- }
499- if (targets_.count (call_dev_type) == 0 ) {
500- LOG (FATAL) << " No target is provided for device " << call_dev_name;
501- }
502- target = targets_[call_dev_type];
503- }
504-
505- CCacheKey key = CCacheKey (func, target);
506- CachedFunc lowered_func = compile_engine_->Lower (key, mod_name_);
507-
508- if (!lowered_funcs_.count (target->str ())) {
509- lowered_funcs_[target->str ()] = IRModule (Map<GlobalVar, BaseFunc>({}));
510- }
511- lowered_funcs_[target->str ()]->Update (lowered_func->funcs );
512- // Update function metadata via looking at all primfuncs
513- UpdateFunctionMetadata (lowered_func, func, target);
514-
515- // Generate the TIR function call
516- CreateFuncCall (GetRef<Call>(op), lowered_func->prim_fn_var ->name_hint );
517415 }
518416
519417 void VisitExpr_ (const VarNode* op) override {
@@ -598,7 +496,7 @@ class AOTExecutorCodegen : public ExprVisitor {
598496 // Create the main PrimFunc to execute the graph. Please note that
599497 // the packed function calls don't pack their arguments. The AOT
600498 // runner function needs to be legalized by the LegalizePackedCalls pass.
601- tir::PrimFunc CreateMainFunc (unsigned int relay_params) {
499+ tir::PrimFunc CreateMainFunc (String mod_name, unsigned int relay_params) {
602500 tir::Stmt body = tir::SeqStmt (stmts_);
603501
604502 // Allocate the sids
@@ -637,7 +535,7 @@ class AOTExecutorCodegen : public ExprVisitor {
637535 // Define the PrimFunc attributes
638536 Map<String, ObjectRef> dict_attrs;
639537 String run_func_name =
640- runtime::get_name_mangled (mod_name_ , runtime::symbol::tvm_run_func_suffix);
538+ runtime::get_name_mangled (mod_name , runtime::symbol::tvm_run_func_suffix);
641539 dict_attrs.Set (" global_symbol" , run_func_name);
642540 dict_attrs.Set (" runner_function" , Bool (true ));
643541
@@ -654,7 +552,7 @@ class AOTExecutorCodegen : public ExprVisitor {
654552 /* ! \brief input and output variables belonging to the main function signature */
655553 Array<tir::Var> main_signature_;
656554 /* ! \brief target device */
657- TargetsMap targets_;
555+ tec::TargetMap targets_;
658556 /* ! \brief target host */
659557 Target target_host_;
660558 /* !
@@ -684,35 +582,68 @@ class AOTExecutorCodegen : public ExprVisitor {
684582 /* ! \brief mapping sid -> tir::Var */
685583 std::unordered_map<int , te::Var> sids_table_;
686584 /* ! \brief function metadata, keyed by function name */
687- std::unordered_map<std::string, IRModule> lowered_funcs_;
688- /* ! \brief lowered funcs */
689585 Map<String, FunctionInfo> function_metadata_;
690- /* ! \brief compile engine */
691- CompileEngine compile_engine_;
692586 /* ! \brief the set of statements that make the program */
693587 std::vector<tir::Stmt> stmts_;
695589 /* ! \brief the list of return sids (note that the function might return more than one output) */
695589 std::vector<int > return_sid_;
696- /* ! \brief the module name we use to mangle the function names */
697- String mod_name_;
698590
699591 public:
// Construct the AOT code generator. Stores the runtime module pointer, the target map
// and the host target, and reads the host target's "unpacked-api" attribute (false when
// absent) into use_unpacked_api_.
// The '+' lines change the target map type from TargetsMap to tec::TargetMap and drop
// the compile_engine_ member initializer.
700- AOTExecutorCodegen (runtime::Module* mod, const TargetsMap & targets, Target target_host)
592+ AOTExecutorCodegen (runtime::Module* mod, const tec::TargetMap & targets, Target target_host)
701593 : mod_(mod),
702594 targets_ (targets),
703595 target_host_(target_host),
704- use_unpacked_api_(target_host->GetAttr<Bool>(" unpacked-api" ).value_or(Bool(false ))),
705- compile_engine_(CompileEngine::Global()) {}
596+ use_unpacked_api_(target_host->GetAttr<Bool>(" unpacked-api" ).value_or(Bool(false ))) {}
706597
707598 LoweredOutput Codegen (relay::Function func, String mod_name) {
708599 auto aot_allocator = AOTOnDemandAllocator ();
709600 aot_allocator.Run (func);
710601
711- // Retrieve the storage map
712- storage_device_map_ = aot_allocator.GetStorageMap ();
713- mod_name_ = mod_name;
602+ // Pre-lowering storage map and memory plan
603+ StorageMap initial_storage_map = aot_allocator.GetStorageMap ();
604+ StaticMemoryPlan memory_plan (initial_storage_map);
605+
606+ // Build a map from each operation to device.
607+ tec::DeviceMap device_context_map;
608+ for (const auto & it : memory_plan->expr_to_storage_info ) {
609+ auto expr = it.first ;
610+ auto storage_info = it.second ;
611+ auto device_types = storage_info->device_types ;
612+ // CHECK_EQ(device_types.size(), 1);
613+ tvm::Device dev;
614+ dev.device_id = 0 ;
615+ dev.device_type = device_types[0 ];
616+ device_context_map.insert ({expr, dev});
617+ }
618+
619+ // This first phase moves from implicit use of compile engine,
620+ // to instead explicitly lowering the incoming IRModule, and then
621+ // performing the preexisting AOT executor code generation phase.
622+ IRModule mod = IRModule::FromExpr (func);
623+ auto lowered_module = tec::LowerTE (
624+ mod, targets_, device_context_map, memory_plan, mod_name, [this ](Function func) {
625+ // We need to maintain the constant map for external
626+ // functions so we pass this processing function which
627+ // allows us to process each function as we lower it.
628+ if (func->GetAttr <String>(attr::kCompiler ).defined ()) {
629+ UpdateConstants (func, ¶ms_);
630+ }
631+
632+ // TODO(@areusch, @jroesch): We should refactor this to
633+ // execute as a further pass, instead of writing data to the
634+ // lowering process directly.
635+ tec::UpdateFunctionMetadata (func, this ->function_metadata_ );
636+ });
714637
715- for (auto input : func->params ) {
638+ auto lowered_main = lowered_module.main_module ->Lookup (" main" );
639+ auto lowered_main_func = GetRef<Function>(lowered_main.as <FunctionNode>());
640+
641+ // Post-lowering storage map for writing main func
642+ auto new_allocator = AOTOnDemandAllocator ();
643+ new_allocator.Run (lowered_main_func);
644+ storage_device_map_ = new_allocator.GetStorageMap ();
645+
646+ for (auto input : lowered_main_func->params ) {
716647 input_vars_.push_back (input);
717648 main_signature_.push_back (tir::Var (" input" , DataType::Handle ()));
718649 }
@@ -732,13 +663,14 @@ class AOTExecutorCodegen : public ExprVisitor {
732663 main_signature_.push_back (tir::Var (" output" , DataType::Handle ()));
733664 }
734665
735- VisitExpr (func ->body );
666+ VisitExpr (lowered_main_func ->body );
736667
737668 // Create the runner function. Please note that the function is not legal yet
738669 // because the packed calls arguments are not wrapped in TVMValues. To make this happen we need
739670 // to run the LegalizePackedCalls pass.
740- auto prim_func = CreateMainFunc (func->params .size ());
741- UpdateMainWorkspaceSize (prim_func, func);
671+ auto prim_func = CreateMainFunc (mod_name, lowered_main_func->params .size ());
672+ UpdateMainWorkspaceSize (prim_func, lowered_main_func);
673+
742674 LoweredOutput ret;
743675
744676 ret.params = std::unordered_map<std::string, std::pair<int , const tvm::runtime::NDArray>>();
@@ -748,17 +680,7 @@ class AOTExecutorCodegen : public ExprVisitor {
748680 std::make_pair (static_cast <int >(param_storage_ids_[param.first ]), param.second )));
749681 }
750682
751- for (auto & kv : lowered_funcs_) {
752- if (ret.lowered_funcs .count (kv.first ) == 0 ) {
753- ret.lowered_funcs .Set (kv.first , IRModule (Map<GlobalVar, BaseFunc>({})));
754- }
755- auto & mod = ret.lowered_funcs [kv.first ];
756- mod->Update (kv.second );
757- ret.lowered_funcs .Set (kv.first , mod);
758- }
759- ret.external_mods = compile_engine_->LowerExternalFunctions ();
760-
761- // Build the TIR IRModule
683+ // Build the TIR IRModule for the AOT function
762684 Map<GlobalVar, BaseFunc> symbol_map;
763685 symbol_map.Set (GlobalVar (::tvm::runtime::symbol::tvm_run_func_suffix), prim_func);
764686 IRModule mod_run (symbol_map);
@@ -774,14 +696,17 @@ class AOTExecutorCodegen : public ExprVisitor {
774696 mod_run = pack_calls (mod_run);
775697 }
776698
777- // Update the lowered functions
699+ ret.function_metadata = std::move (function_metadata_);
700+
701+ ret.lowered_funcs = lowered_module.per_target_module ;
702+ ret.external_mods = lowered_module.external_mods ;
703+
778704 auto target_host_str = target_host_->str ();
779705 if (ret.lowered_funcs .find (target_host_str) != ret.lowered_funcs .end ()) {
780706 ret.lowered_funcs [target_host_str]->Update (mod_run);
781707 } else {
782708 ret.lowered_funcs .Set (target_host_str, mod_run);
783709 }
784- ret.function_metadata = std::move (function_metadata_);
785710
786711 std::vector<String> input_var_names (input_vars_.size ());
787712 std::transform (input_vars_.begin (), input_vars_.end (), input_var_names.begin (),
@@ -845,15 +770,15 @@ class AOTExecutorCodegenModule : public runtime::ModuleNode {
845770
846771 private:
847772 void init (void * mod, Map<Integer, tvm::Target> tmp) {
848- TargetsMap targets;
773+ tec::TargetMap targets;
849774 Target target_host;
850775 for (const auto & it : tmp) {
851776 auto dev_type = it.first .as <tir::IntImmNode>();
852777 if (!target_host.defined () && it.second ->kind ->device_type == kDLCPU ) {
853778 target_host = it.second ;
854779 }
855780 ICHECK (dev_type);
856- targets[dev_type->value ] = it.second ;
781+ targets[static_cast <DLDeviceType>( dev_type->value ) ] = it.second ;
857782 }
858783 codegen_ = std::make_shared<AOTExecutorCodegen>(reinterpret_cast <runtime::Module*>(mod),
859784 targets, target_host);
0 commit comments