@@ -39,8 +39,9 @@ namespace usmp {
3939class PoolAllocationToOffsetConverter : public StmtExprMutator {
4040 public:
4141 explicit PoolAllocationToOffsetConverter (const IRModule& module ,
42- const Map<tir::Stmt, PoolAllocation>& pool_allocations)
43- : pool_allocations_(pool_allocations) {
42+ const Map<tir::Stmt, PoolAllocation>& pool_allocations,
43+ bool emit_tvmscript_printable = false )
44+ : pool_allocations_(pool_allocations), emit_tvmscript_printable_(emit_tvmscript_printable) {
4445 module_ = module ->ShallowCopy ();
4546 for (const auto & gv_func : module_->functions ) {
4647 function_global_vars_.Set (gv_func.first ->name_hint , gv_func.first );
@@ -51,7 +52,6 @@ class PoolAllocationToOffsetConverter : public StmtExprMutator {
5152 Allocate allocate_node = Downcast<Allocate>(kv.first );
5253 PoolAllocation pool_allocation = kv.second ;
5354 PoolInfo pool_info = pool_allocation->pool_info ;
54- pool_ordering_.insert (pool_info);
5555 int byte_pool_offset = pool_allocation->byte_offset ->value ;
5656 int required_pool_size_for_allocation =
5757 byte_pool_offset + CalculateExtentsSize (allocate_node.operator ->());
@@ -64,12 +64,26 @@ class PoolAllocationToOffsetConverter : public StmtExprMutator {
6464 }
6565 }
6666 }
67+
68+ for (const auto & kv : all_pools_sizes_) {
69+ PoolInfo pi = kv.first ;
70+ int allocated_size = kv.second ;
71+ allocated_pool_ordering_.push_back (AllocatedPoolInfo (pi, allocated_size));
72+ }
73+ std::sort (allocated_pool_ordering_.begin (), allocated_pool_ordering_.end (),
74+ [](const AllocatedPoolInfo& lhs, const AllocatedPoolInfo& rhs) {
75+ if (lhs->pool_info ->pool_name < rhs->pool_info ->pool_name ) {
76+ return true ;
77+ }
78+ return false ;
79+ });
6780 }
6881 IRModule operator ()();
6982
7083 private:
7184 PrimExpr VisitExpr_ (const CallNode* op) override ;
7285 Stmt VisitStmt_ (const AllocateNode* op) override ;
86+ // PrimExpr VisitExpr_(const VarNode* op) override;
7387 PrimExpr VisitExpr_ (const LoadNode* op) override ;
7488 Stmt VisitStmt_ (const StoreNode* op) override ;
7589
@@ -79,6 +93,7 @@ class PoolAllocationToOffsetConverter : public StmtExprMutator {
7993 struct ScopeInfo {
8094 Array<tir::Var> params;
8195 Map<PoolInfo, tir::Var> pools_to_params;
96+ Array<AllocatedPoolInfo> allocated_pool_params;
8297 Map<tir::Var, Buffer> buffer_map;
8398 };
8499
@@ -101,7 +116,7 @@ class PoolAllocationToOffsetConverter : public StmtExprMutator {
101116 /* ! \brief This is a helper to append the pool args to
102117 * the callsite of the function.
103118 */
104- Array<PrimExpr> AppendPoolParamsToArgs (const CallNode* op );
119+ Array<PrimExpr> AppendPoolParamsToArgs (const Array<PrimExpr>& args );
105120 /* ! \brief Some arguments that used to be Allocate nodes
106121 * should be replaced by Let nodes in the pass that loads
107122 * the space from a pool variable.
@@ -117,7 +132,7 @@ class PoolAllocationToOffsetConverter : public StmtExprMutator {
117132 /* ! \brief The input allocate node to PoolAllocation map */
118133 Map<tir::Stmt, PoolAllocation> pool_allocations_;
119134 /* ! \brief The ordered list of allocated pools, used to ensure a unique order of args for functions */
120- std::set<PoolInfo> pool_ordering_ ;
135+ std::vector<AllocatedPoolInfo> allocated_pool_ordering_ ;
121136 /* ! \brief The storage of calculated pool size at init */
122137 std::unordered_map<PoolInfo, int , ObjectPtrHash, ObjectPtrEqual> all_pools_sizes_;
123138 /* ! \brief The AoT codegen uses extern_calls due to some functions not being exposed in the TIR
@@ -130,6 +145,10 @@ class PoolAllocationToOffsetConverter : public StmtExprMutator {
130145 Map<tir::Var, tir::Var> allocate_buf_to_let_var_;
131146 /* ! \brief A counter to give references to pools a reproducible unique set of names */
132147 int pool_var_count_ = 0 ;
148+ /* ! \brief When true, items that are not TVMScript printable are removed from the IRModule (for unit tests) */
149+ bool emit_tvmscript_printable_ = false ;
150+ /* ! \brief The set of PrimFuncs already created with pool params, so they are not rewritten again */
151+ std::unordered_set<PrimFunc, ObjectPtrHash, ObjectPtrEqual> visited_primfuncs;
133152};
134153
135154PoolAllocationToOffsetConverter::ScopeInfo PoolAllocationToOffsetConverter::UpdateFunctionScopeInfo (
@@ -138,14 +157,22 @@ PoolAllocationToOffsetConverter::ScopeInfo PoolAllocationToOffsetConverter::Upda
138157 si.params = original_func->params ;
139158 si.buffer_map = original_func->buffer_map ;
140159 Map<tir::Var, PoolInfo> ret;
141- for (const PoolInfo& pool_info : pool_ordering_) {
160+ for (const AllocatedPoolInfo& allocated_pool_info : allocated_pool_ordering_) {
161+ PoolInfo pool_info = allocated_pool_info->pool_info ;
142162 String pool_ref_name = pool_info->pool_name + " _" + std::to_string (pool_var_count_++);
143163 String var_name = pool_ref_name + " _var" ;
144164 DataType elem_dtype = DataType::UInt (8 );
145165 Var buffer_var (var_name, PointerType (PrimType (elem_dtype), " global" ));
146- Var pool_var (var_name, DataType::Handle ());
166+ Var pool_var;
167+ if (!emit_tvmscript_printable_) {
168+ pool_var = Var (var_name, PointerType (PrimType (elem_dtype), " global" ));
169+ } else {
170+ pool_var = Var (var_name, DataType::Handle (8 ));
171+ }
147172 si.params .push_back (pool_var);
148173 si.pools_to_params .Set (pool_info, pool_var);
174+ si.allocated_pool_params .push_back (AllocatedPoolInfo (
175+ allocated_pool_info->pool_info , allocated_pool_info->allocated_size , pool_var));
149176
150177 int pool_size = all_pools_sizes_[pool_info];
151178 String buffer_var_name = pool_ref_name + " _buffer_var" ;
@@ -157,22 +184,40 @@ PoolAllocationToOffsetConverter::ScopeInfo PoolAllocationToOffsetConverter::Upda
157184
158185PrimFunc PoolAllocationToOffsetConverter::CreatePrimFuncWithPoolParams (
159186 const PrimFunc& original_primfunc) {
160- ScopeInfo si = UpdateFunctionScopeInfo (original_primfunc);
161- this ->scope_stack .push (si);
162- Stmt new_body = this ->VisitStmt (original_primfunc->body );
163- this ->scope_stack .pop ();
164- return PrimFunc (si.params , new_body, original_primfunc->ret_type , si.buffer_map ,
165- original_primfunc->attrs );
187+ // Only create the new function if it was not modified with pool params
188+ if (visited_primfuncs.find (original_primfunc) == visited_primfuncs.end ()) {
189+ ScopeInfo si = UpdateFunctionScopeInfo (original_primfunc);
190+ this ->scope_stack .push (si);
191+ Stmt new_body = this ->VisitStmt (original_primfunc->body );
192+ this ->scope_stack .pop ();
193+ DictAttrs original_attrs = original_primfunc->attrs ;
194+ // We dont need attrs of PrimFunc that might include non printable attrs such as target
195+ // for unit tests where emit_tvmscript_printable_ is to be used.
196+ if (emit_tvmscript_printable_) {
197+ original_attrs = DictAttrs ();
198+ }
199+ PrimFunc ret =
200+ PrimFunc (si.params , new_body, original_primfunc->ret_type , si.buffer_map , original_attrs);
201+ if (!emit_tvmscript_printable_) {
202+ return WithAttr (ret, tvm::attr::kPoolArgs , si.allocated_pool_params );
203+ }
204+ visited_primfuncs.insert (ret);
205+ return ret;
206+ }
207+ return original_primfunc;
166208}
167209
168- Array<PrimExpr> PoolAllocationToOffsetConverter::AppendPoolParamsToArgs (const CallNode* op) {
210+ Array<PrimExpr> PoolAllocationToOffsetConverter::AppendPoolParamsToArgs (
211+ const Array<PrimExpr>& args) {
169212 Array<PrimExpr> new_args;
170- for (const auto & arg : op-> args ) {
213+ for (const auto & arg : args) {
171214 new_args.push_back (VisitExpr (arg));
172215 }
173- for (const auto & pools_vars : this ->scope_stack .top ().pools_to_params ) {
216+ ScopeInfo top_scope = this ->scope_stack .top ();
217+ for (const auto & pools_vars : top_scope.pools_to_params ) {
174218 tir::Var pool_var = pools_vars.second ;
175- new_args.push_back (pool_var);
219+ Buffer buffer_var = top_scope.buffer_map [pool_var];
220+ new_args.push_back (buffer_var->data );
176221 }
177222 return new_args;
178223}
@@ -192,24 +237,30 @@ Array<PrimExpr> PoolAllocationToOffsetConverter::ReplaceAllocateArgsWithLetArgs(
192237}
193238
194239PrimExpr PoolAllocationToOffsetConverter::VisitExpr_ (const CallNode* op) {
195- if (op->op .same_as (builtin::call_extern ())) {
240+ if (op->op .same_as (builtin::call_extern ()) || op-> op . same_as ( builtin::tvm_call_cpacked ()) ) {
196241 String func_name = Downcast<StringImm>(op->args [0 ])->value ;
197- GlobalVar gv = function_global_vars_.at (func_name);
198- PrimFunc func = Downcast<PrimFunc>(module_->Lookup (gv));
199- PrimFunc prim_func = CreatePrimFuncWithPoolParams (func);
200- module_->Update (gv, prim_func);
201- Array<PrimExpr> new_args = AppendPoolParamsToArgs (op);
202- new_args = ReplaceAllocateArgsWithLetArgs (new_args);
203- return Call (op->dtype , builtin::call_extern (), new_args);
204- } else if (op->op ->IsInstance <PrimFuncNode>()) {
242+ Array<PrimExpr> new_args;
243+ if (function_global_vars_.find (func_name) != function_global_vars_.end ()) {
244+ GlobalVar gv = function_global_vars_.at (func_name);
245+ PrimFunc func = Downcast<PrimFunc>(module_->Lookup (gv));
246+ PrimFunc prim_func = CreatePrimFuncWithPoolParams (func);
247+ module_->Update (gv, prim_func);
248+ new_args = AppendPoolParamsToArgs (op->args );
249+ new_args = ReplaceAllocateArgsWithLetArgs (new_args);
250+ } else {
251+ new_args = ReplaceAllocateArgsWithLetArgs (op->args );
252+ }
253+ return Call (op->dtype , op->op , new_args);
254+ }
255+ if (op->op ->IsInstance <PrimFuncNode>()) {
205256 PrimFunc func = Downcast<PrimFunc>(op->op );
206257 PrimFunc prim_func = CreatePrimFuncWithPoolParams (func);
207- Array<PrimExpr> new_args = AppendPoolParamsToArgs (op);
258+ Array<PrimExpr> new_args = AppendPoolParamsToArgs (op->args );
259+ new_args = AppendPoolParamsToArgs (new_args);
208260 new_args = ReplaceAllocateArgsWithLetArgs (new_args);
209261 return Call (op->dtype , prim_func, new_args);
210- } else {
211- return StmtExprMutator::VisitExpr_ (op);
212262 }
263+ return StmtExprMutator::VisitExpr_ (op);
213264}
214265
215266Stmt PoolAllocationToOffsetConverter::VisitStmt_ (const AllocateNode* op) {
@@ -219,12 +270,19 @@ Stmt PoolAllocationToOffsetConverter::VisitStmt_(const AllocateNode* op) {
219270 Var param = scope_info.pools_to_params [pool_allocation->pool_info ];
220271 Buffer buffer_var = scope_info.buffer_map [param];
221272 ICHECK (pool_allocation->byte_offset < all_pools_sizes_[pool_allocation->pool_info ]);
222- Load load_node = Load (op->dtype , buffer_var->data , pool_allocation->byte_offset , op->condition );
223- Var tir_var (op->buffer_var ->name_hint + " _let" , op->dtype );
273+ Load load_node =
274+ Load (DataType::UInt (8 ), buffer_var->data , pool_allocation->byte_offset , op->condition );
275+ Call address_of_load = Call (DataType::Handle (8 ), builtin::address_of (), {load_node});
276+ Var tir_var;
277+ if (!emit_tvmscript_printable_) {
278+ tir_var = Var (op->buffer_var ->name_hint + " _let" , op->buffer_var ->type_annotation );
279+ } else {
280+ tir_var = Var (op->buffer_var ->name_hint + " _let" , DataType::Handle (8 ));
281+ }
224282 allocate_buf_to_let_var_.Set (op->buffer_var , tir_var);
225283 Stmt new_body = VisitStmt (op->body );
226284 allocate_buf_to_let_var_.erase (op->buffer_var );
227- return LetStmt (tir_var, load_node , new_body);
285+ return LetStmt (tir_var, address_of_load , new_body);
228286 }
229287 return StmtExprMutator::VisitStmt_ (op);
230288}
@@ -252,17 +310,31 @@ IRModule PoolAllocationToOffsetConverter::operator()() {
252310 this ->scope_stack .push (si);
253311 Stmt main_func_body = this ->VisitStmt (main_func->body );
254312 this ->scope_stack .pop ();
255- module_->Update (gv, PrimFunc (si.params , main_func_body, main_func->ret_type , si.buffer_map ,
256- main_func->attrs ));
313+ // We don't need attrs of PrimFunc that might include non printable attrs such as target
314+ // for unit tests where emit_tvmscript_printable_ is to be used.
315+ if (!emit_tvmscript_printable_) {
316+ main_func =
317+ PrimFunc (si.params , main_func_body, main_func->ret_type , si.buffer_map , main_func->attrs );
318+ main_func = WithAttr (main_func, tvm::attr::kPoolArgs , si.allocated_pool_params );
319+ } else {
320+ main_func =
321+ PrimFunc (si.params , main_func_body, main_func->ret_type , si.buffer_map , DictAttrs ());
322+ }
323+ module_->Update (gv, main_func);
324+ if (!emit_tvmscript_printable_) {
325+ return WithAttr (this ->module_ , tvm::attr::kPoolArgs , si.allocated_pool_params );
326+ }
257327 return this ->module_ ;
258328}
259329
260330namespace transform {
261331
262332tvm::transform::Pass ConvertPoolAllocationsToOffsets (
263- const Map<tir::Stmt, PoolAllocation>& pool_allocations) {
333+ const Map<tir::Stmt, PoolAllocation>& pool_allocations,
334+ Bool emit_tvmscript_printable = Bool(false )) {
264335 auto pass_func = [=](IRModule m, tvm::transform::PassContext ctx) {
265- return Downcast<IRModule>(PoolAllocationToOffsetConverter (m, pool_allocations)());
336+ return Downcast<IRModule>(PoolAllocationToOffsetConverter (
337+ m, pool_allocations, emit_tvmscript_printable->value != 0 )());
266338 };
267339 return tvm::transform::CreateModulePass (pass_func, 0 , " tir.usmp.ConvertPoolAllocationsToOffsets" ,
268340 {});
0 commit comments