3131#include < tvm/tir/expr.h>
3232#include < tvm/tir/function.h>
3333#include < tvm/tir/stmt.h>
34+ #include < tvm/tir/transform.h>
3435
3536#include < algorithm>
3637#include < list>
@@ -44,52 +45,179 @@ namespace tvm {
4445namespace relay {
4546namespace backend {
4647
/*!
 * \brief Storage records attached to one expression of the runner function.
 *
 * The three vectors are filled side by side by the allocator below: entry i of
 * each vector describes the same temporary.
 */
struct StorageInfo {
  /*! \brief integer identifier of each temporary */
  std::vector<int> ids{};
  /*! \brief size of each temporary, in bytes */
  std::vector<int> sizes_bytes{};
  /*! \brief device type of each temporary */
  std::vector<int> dev_types{};
};
60+
4761using IntegerArray = Array<Integer>;
4862using TargetsMap = std::unordered_map<int , Target>;
63+ using StorageMap =
64+ std::unordered_map<Expr, StorageInfo, runtime::ObjectPtrHash, runtime::ObjectPtrEqual>;
4965
50- class AotReturnSidVisitor : public ExprVisitor {
66+ /* *
67+ * This is an on demand allocator for AOT. A new temporary
68+ * (storage allocator identifier) is allocated for each operation.
69+ */
70+ class AOTOnDemandAllocator : public ExprVisitor {
5171 public:
52- explicit AotReturnSidVisitor (Map<Expr, Array<IntegerArray>> storage_device_map)
53- : storage_device_map_{storage_device_map}, return_sid_{-1 } {}
72+ // run the visitor on a function.
73+ void Run (const Function& func) {
74+ node_device_map_ = CollectDeviceInfo (func);
5475
55- IntegerArray FindReturnSid (Function func) {
56- VisitExpr (func->body );
57- return return_sid_;
76+ for (Expr param : func->params ) {
77+ CreateSid (param.operator ->());
78+ }
79+
80+ GetSid (func->body );
5881 }
5982
60- protected:
61- void AssignReturnSid (Expr e) {
62- auto iter = storage_device_map_.find (e);
63- if (iter != storage_device_map_.end ()) {
64- return_sid_ = (*iter).second [0 ];
83+ std::vector<int > GetReturnIds () const { return return_ids_; }
84+
85+ StorageMap GetStorageMap () const { return storage_device_map_; }
86+
87+ void VisitExpr_ (const ConstantNode* op) final {
88+ CreateSid (op);
89+ AssignReturnSid (GetRef<Expr>(op));
90+ }
91+
92+ void VisitExpr_ (const CallNode* op) final {
93+ // create token for the call node.
94+ CreateSid (op);
95+ for (Expr arg : op->args ) {
96+ GetSid (arg);
6597 }
98+ AssignReturnSid (GetRef<Expr>(op));
6699 }
67100
68- void VisitExpr_ (const ConstantNode* cn) override {
69- ExprVisitor::VisitExpr_ (cn );
70- AssignReturnSid (GetRef<Expr>(cn ));
101+ void VisitExpr_ (const VarNode* op) final {
102+ ExprVisitor::VisitExpr_ (op );
103+ AssignReturnSid (GetRef<Expr>(op ));
71104 }
72105
73- void VisitExpr_ (const VarNode* vn) override {
74- ExprVisitor::VisitExpr_ (vn);
75- AssignReturnSid (GetRef<Expr>(vn));
106+ void VisitExpr_ (const FunctionNode* op) final {
107+ // do not recurse into sub function.
76108 }
77109
78- void VisitExpr_ (const CallNode* cn) override {
79- ExprVisitor::VisitExpr_ (cn);
80- AssignReturnSid (GetRef<Expr>(cn));
110+ void VisitExpr_ (const GlobalVarNode* op) final {
111+ // Do nothing.
81112 }
82113
83- void VisitExpr_ (const LetNode* op) override { VisitExpr (op->body ); }
114+ void VisitExpr_ (const OpNode* op) final {
115+ // Do nothing.
116+ }
117+
118+ void VisitExpr_ (const TupleNode* op) final {
119+ StorageInfo field_sid;
120+ Expr expr = GetRef<Expr>(op);
121+ for (Expr field : op->fields ) {
122+ auto sid = GetSid (field);
123+ field_sid.ids .insert (field_sid.ids .end (), sid.ids .begin (), sid.ids .end ());
124+ field_sid.dev_types .insert (field_sid.dev_types .end (), sid.dev_types .begin (),
125+ sid.dev_types .end ());
126+ field_sid.sizes_bytes .insert (field_sid.sizes_bytes .end (), sid.sizes_bytes .begin (),
127+ sid.sizes_bytes .end ());
128+ }
129+
130+ storage_device_map_[expr] = field_sid;
131+ AssignReturnSid (expr);
132+ }
84133
85- void VisitExpr_ (const TupleNode* tn) override {
86- ExprVisitor::VisitExpr_ (tn);
87- AssignReturnSid (GetRef<Expr>(tn));
134+ void VisitExpr_ (const TupleGetItemNode* op) final {
135+ Expr expr = GetRef<Expr>(op);
136+ const auto & sid = GetSid (op->tuple );
137+ ICHECK_LT (static_cast <size_t >(op->index ), sid.ids .size ());
138+ storage_device_map_[expr].ids = {sid.ids [op->index ]};
139+ storage_device_map_[expr].sizes_bytes = {sid.sizes_bytes [op->index ]};
140+ storage_device_map_[expr].dev_types = {sid.dev_types [op->index ]};
141+ AssignReturnSid (expr);
88142 }
89143
144+ void VisitExpr_ (const IfNode* op) final { LOG (FATAL) << " if is not supported." ; }
145+
146+ void VisitExpr_ (const LetNode* op) final { LOG (FATAL) << " if is not supported." ; }
147+
90148 private:
91- Map<Expr, Array<IntegerArray>> storage_device_map_;
92- IntegerArray return_sid_;
149+ void AssignReturnSid (Expr e) {
150+ auto iter = storage_device_map_.find (e);
151+ if (iter != storage_device_map_.end ()) {
152+ return_ids_ = (*iter).second .ids ;
153+ }
154+ }
155+ /* !
156+ * \brief ceil(size/word_size) to get number of words.
157+ * \param size The original size.
158+ * \param word_size The element size.
159+ */
160+ static size_t DivRoundUp (size_t size, size_t word_size) {
161+ return (size + word_size - 1 ) / word_size;
162+ }
163+ /* !
164+ * \brief Get the memory requirement.
165+ * \param prototype The prototype token.
166+ * \return The required memory size.
167+ */
168+ size_t GetMemorySize (const TensorTypeNode* ttype) {
169+ ICHECK (ttype != nullptr );
170+ size_t size = 1 ;
171+ for (IndexExpr dim : ttype->shape ) {
172+ const int64_t * pval = tir::as_const_int (dim);
173+ ICHECK (pval != nullptr ) << " Cannot allocate memory symbolic tensor shape " << ttype->shape ;
174+ ICHECK_GE (*pval, 0 ) << " Cannot allocate memory for tensor with negative shape" << *pval;
175+ size *= static_cast <size_t >(pval[0 ]);
176+ }
177+ size *= DivRoundUp (ttype->dtype .bits () * ttype->dtype .lanes (), 8 );
178+ return size;
179+ }
180+ /* !
181+ * \brief Get the necessary token.
182+ * \param expr The expression.
183+ * \return The corresponding token.
184+ */
185+ StorageInfo GetSid (const Expr& expr) {
186+ this ->VisitExpr (expr);
187+ auto it = storage_device_map_.find (expr);
188+ ICHECK (it != storage_device_map_.end ());
189+ return it->second ;
190+ }
191+
192+ void CreateSid (const ExprNode* op) {
193+ StorageInfo sid;
194+ Expr expr = GetRef<Expr>(op);
195+ int device_type = node_device_map_.count (GetRef<Expr>(op)) ? node_device_map_[expr]->value : 0 ;
196+ if (const auto * tuple_type = op->checked_type ().as <TupleTypeNode>()) {
197+ for (Type t : tuple_type->fields ) {
198+ const auto * ttype = t.as <TensorTypeNode>();
199+ ICHECK (ttype);
200+ sid.ids .push_back (sid_++);
201+ sid.dev_types .push_back (device_type);
202+ sid.sizes_bytes .push_back (GetMemorySize (ttype));
203+ }
204+ } else {
205+ const auto * ttype = op->checked_type ().as <TensorTypeNode>();
206+ ICHECK (ttype);
207+ sid.ids .push_back (sid_++);
208+ sid.dev_types .push_back (device_type);
209+ sid.sizes_bytes .push_back (GetMemorySize (ttype));
210+ }
211+ storage_device_map_[expr] = sid;
212+ }
213+ /* ! \brief mapping of expression -> storageInfo*/
214+ StorageMap storage_device_map_;
215+ /* ! \brief mapping of expression -> device type*/
216+ Map<Expr, Integer> node_device_map_;
217+ /* ! \brief current id of the temporary allocated*/
218+ int sid_{0 };
219+ /* ! \brief the set of identifiers that are return variables */
220+ std::vector<int > return_ids_;
93221};
94222
95223/* ! \brief Code generator for AOT executor */
@@ -120,14 +248,14 @@ class AOTExecutorCodegen : public ExprVisitor {
120248 * \brief Return a vector of variables that represents the sids for the given Relay Expr
121249 */
122250 std::vector<tir::Var> PackSid (Expr expr) {
123- Array<IntegerArray> sids = storage_device_map_[expr];
251+ auto sids = storage_device_map_[expr];
124252 std::vector<tir::Var> sid_vars;
125253
126254 // Note that an expression can have multiple sids associated with it
127255 // e.g., returning multiple values from a function
128- for (const auto & sid : sids[ 0 ] ) {
256+ for (const auto & sid : sids. ids ) {
129257 // Determine if an sid is an output buffer
130- int sid_int = static_cast < int >(( sid. as <IntImmNode>())-> value ) ;
258+ int sid_int = sid;
131259 auto output_iter = std::find (return_sid_.begin (), return_sid_.end (), sid_int);
132260 if (output_iter != return_sid_.end ()) {
133261 int output_index = std::distance (return_sid_.begin (), output_iter);
@@ -390,8 +518,8 @@ class AOTExecutorCodegen : public ExprVisitor {
390518 }
391519
392520 ICHECK_GE (storage_device_map_.count (expr), 0 );
393- auto & device_type = storage_device_map_[expr][ 1 ] ;
394- auto call_dev_type = device_type[0 ]-> value ;
521+ auto & device_type = storage_device_map_[expr]. dev_types ;
522+ auto call_dev_type = device_type[0 ];
395523 // Normal Relay Function
396524 if (targets_.size () == 1 ) {
397525 // homogeneous execution.
@@ -428,14 +556,14 @@ class AOTExecutorCodegen : public ExprVisitor {
428556
429557 // If the Var node is an output node we need to copy the content of the variable to the output
430558 // It's safe to check the SID here because Var StorageToken are never reallocated
431- Array<IntegerArray> sids = storage_device_map_[expr];
559+ auto sids = storage_device_map_[expr];
432560
433- auto output_iter = std::find (return_sid_.begin (), return_sid_.end (),
434- static_cast <int >((sids[0 ][0 ].as <IntImmNode>())->value ));
561+ auto output_iter = std::find (return_sid_.begin (), return_sid_.end (), sids.ids [0 ]);
435562 if (output_iter != return_sid_.end ()) {
436563 int output_index = std::distance (return_sid_.begin (), output_iter);
437564 auto var_expr = FindExpr (expr);
438- CopyToOutput (main_signature_[input_vars_.size () + output_index], var_expr[0 ], sids[2 ][0 ]);
565+ CopyToOutput (main_signature_[input_vars_.size () + output_index], var_expr[0 ],
566+ sids.sizes_bytes [0 ]);
439567 }
440568 }
441569
@@ -444,18 +572,18 @@ class AOTExecutorCodegen : public ExprVisitor {
444572 size_t index = params_.size ();
445573 std::string name = " p" + std::to_string (index);
446574
447- param_storage_ids_[name] = storage_device_map_[expr][0 ][ 0 ]-> value ;
575+ param_storage_ids_[name] = storage_device_map_[expr]. ids [0 ];
448576 params_[name] = op->data ;
449577 params_by_expr_.Set (expr, name);
450578
451579 // If the Constant node is an output node we need to copy the content of the parameter to the
452580 // output. A Constant node can only produce a single output.
453- Array<IntegerArray> sids = storage_device_map_[expr];
454- auto output_iter = std::find (return_sid_.begin (), return_sid_.end (),
455- static_cast <int >((sids[0 ][0 ].as <IntImmNode>())->value ));
581+ auto sids = storage_device_map_[expr];
582+ auto output_iter = std::find (return_sid_.begin (), return_sid_.end (), sids.ids [0 ]);
456583 if (output_iter != return_sid_.end ()) {
457584 int output_index = std::distance (return_sid_.begin (), output_iter);
458- CopyToOutput (main_signature_[input_vars_.size () + output_index], PackParam (expr), sids[2 ][0 ]);
585+ CopyToOutput (main_signature_[input_vars_.size () + output_index], PackParam (expr),
586+ sids.sizes_bytes [0 ]);
459587 }
460588 }
461589
@@ -511,9 +639,9 @@ class AOTExecutorCodegen : public ExprVisitor {
511639 continue ;
512640 }
513641
514- for (unsigned int i = 0 ; i < kv.second [ 0 ] .size (); i++) {
515- int size = kv.second [ 2 ] [i];
516- int sid = static_cast < int >(( kv.second [ 0 ][i]. as <IntImmNode>())-> value ) ;
642+ for (unsigned int i = 0 ; i < kv.second . ids .size (); i++) {
643+ int size = kv.second . sizes_bytes [i];
644+ int sid = kv.second . ids [i] ;
517645
518646 if (std::find (return_sid_.begin (), return_sid_.end (), sid) != return_sid_.end ()) {
519647 continue ;
@@ -523,6 +651,8 @@ class AOTExecutorCodegen : public ExprVisitor {
523651 // so we don't pay the price of allocation for every inference
524652 if (!allocated[sid]) {
525653 body = tir::Allocate (sids_table_[sid], DataType::Int (8 ), {size}, tir::const_true (), body);
654+ body = tir::AttrStmt (sids_table_[sid], tir::attr::storage_scope, tir::StringImm (" global" ),
655+ body);
526656 }
527657 allocated[sid] = true ;
528658 }
@@ -566,7 +696,8 @@ class AOTExecutorCodegen : public ExprVisitor {
566696 std::unordered_map<std::string, int64_t > param_storage_ids_;
567697
568698 /* ! \brief plan memory of device result */
569- Map<Expr, Array<IntegerArray>> storage_device_map_;
699+ StorageMap storage_device_map_;
700+ /* ! \brief mapping sid -> tir::Var */
570701 std::unordered_map<int , te::Var> sids_table_;
571702 /* ! \brief lowered funcs */
572703 std::unordered_map<std::string, IRModule> lowered_funcs_;
@@ -577,7 +708,7 @@ class AOTExecutorCodegen : public ExprVisitor {
577708 /* ! \brief the set of statements that make the program */
578709 std::vector<tir::Stmt> stmts_;
579710 /* ! \brief the list of return sids (note that the function might return more than one output) */
580- IntegerArray return_sid_;
711+ std::vector< int > return_sid_;
581712
582713 public:
583714 AOTExecutorCodegen (runtime::Module* mod, const TargetsMap& targets, Target target_host)
@@ -588,9 +719,11 @@ class AOTExecutorCodegen : public ExprVisitor {
588719 }
589720
590721 LoweredOutput Codegen (relay::Function func) {
591- // Get the module, storage map and token sizes
592- auto pf = GetPackedFunc (" relay.backend.GraphPlanMemory" );
593- storage_device_map_ = (*pf)(func);
722+ auto aot_allocator = AOTOnDemandAllocator ();
723+ aot_allocator.Run (func);
724+
725+ // Retrieve the storage map
726+ storage_device_map_ = aot_allocator.GetStorageMap ();
594727
595728 int input_index = 0 ;
596729 for (auto input : func->params ) {
@@ -600,14 +733,14 @@ class AOTExecutorCodegen : public ExprVisitor {
600733
601734 // Define the storage allocator ids
602735 for (auto kv : storage_device_map_) {
603- for (const auto & sid : kv.second [ 0 ] ) {
736+ for (const auto & sid : kv.second . ids ) {
604737 te::Var sid_var (MakeString (" sid_" , sid), PointerType (PrimType (DataType::Int (8 ))));
605738 sids_table_[sid] = sid_var;
606739 }
607740 }
608741
609- // Find the return sid
610- return_sid_ = AotReturnSidVisitor (storage_device_map_). FindReturnSid (func );
742+ // Retrieve the return sids
743+ return_sid_ = aot_allocator. GetReturnIds ( );
611744 for (unsigned int output_index = 0 ; output_index < return_sid_.size (); output_index++) {
612745 main_signature_.push_back (tir::Var (MakeString (" output_" , output_index), DataType::Handle ()));
613746 }
@@ -635,14 +768,21 @@ class AOTExecutorCodegen : public ExprVisitor {
635768 }
636769 ret.external_mods = compile_engine_->LowerExternalFunctions ();
637770
771+ // Build the TIR IRModule
772+ Map<GlobalVar, BaseFunc> symbol_map;
773+ symbol_map.Set (GlobalVar (::tvm::runtime::symbol::tvm_run_func_prefix), prim_func);
774+ IRModule mod_run (symbol_map);
775+
776+ // Apply storage rewrite pass to the runner function to do memory planning
777+ auto storage_rewrite = tir::transform::StorageRewrite ();
778+ mod_run = storage_rewrite (mod_run);
779+
780+ // Update the lowered functions
638781 auto target_host_str = target_host_->str ();
639782 if (ret.lowered_funcs .find (target_host_str) != ret.lowered_funcs .end ()) {
640- ret.lowered_funcs [target_host_str]->Add (
641- GlobalVar (::tvm::runtime::symbol::tvm_run_func_prefix), prim_func);
783+ ret.lowered_funcs [target_host_str]->Update (mod_run);
642784 } else {
643- Map<GlobalVar, BaseFunc> symbol_map;
644- symbol_map.Set (GlobalVar (::tvm::runtime::symbol::tvm_run_func_prefix), prim_func);
645- ret.lowered_funcs .Set (target_host_str, IRModule (symbol_map));
785+ ret.lowered_funcs .Set (target_host_str, mod_run);
646786 }
647787 ret.function_metadata = std::move (function_metadata_);
648788 ret.metadata =
0 commit comments