@@ -135,20 +135,91 @@ Stmt RewriteReturn(Stmt body, Var ret_var, Var ret_tcode) {
135135 return rewriter (body);
136136}
137137
138+ class SubroutineCallRewriter : public StmtExprMutator {
139+ public:
140+ static Optional<Stmt> Apply (const Map<GlobalVar, String>& packed_func_methods, Stmt stmt) {
141+ SubroutineCallRewriter rewriter (packed_func_methods);
142+ stmt = rewriter.VisitStmt (std::move (stmt));
143+ if (rewriter.made_change_ ) {
144+ return stmt;
145+ } else {
146+ return NullOpt;
147+ }
148+ }
149+
150+ private:
151+ explicit SubroutineCallRewriter (const Map<GlobalVar, String>& packed_func_methods)
152+ : packed_func_methods(packed_func_methods) {}
153+
154+ PrimExpr VisitExpr_ (const CallNode* op) override {
155+ auto node = Downcast<Call>(StmtExprMutator::VisitExpr_ (op));
156+
157+ if (auto * gvar_ptr = node->op .as <GlobalVarNode>()) {
158+ auto gvar = GetRef<GlobalVar>(gvar_ptr);
159+ if (auto symbol = packed_func_methods.Get (gvar)) {
160+ Array<PrimExpr> cpacked_args;
161+ cpacked_args.push_back (tir::StringImm (symbol.value ()));
162+ for (auto arg : node->args ) {
163+ cpacked_args.push_back (arg);
164+ }
165+
166+ // push an empty handle to be compatible with current cpacked convention
167+ cpacked_args.push_back (tir::make_zero (DataType::Handle ()));
168+ made_change_ = true ;
169+ return tir::Call (node->dtype , tir::builtin::tvm_call_cpacked (), cpacked_args);
170+ }
171+ }
172+
173+ return node;
174+ }
175+ const Map<GlobalVar, String>& packed_func_methods;
176+ bool made_change_{false };
177+ };
178+
138179inline Stmt MakeAssertEQ (PrimExpr lhs, PrimExpr rhs, std::string msg) {
139180 return AssertStmt (lhs == rhs, tvm::tir::StringImm (msg), Evaluate (0 ));
140181}
141182
142- PrimFunc MakePackedAPI (PrimFunc&& func) {
183+ /* \brief Return the global_symbol of the function, if it should be updated
184+ *
185+ * \param func The function to be inspected
186+ *
187+ * \returns The global_symbol to be used for the function at call
188+ * sites, or NullOpt if the function is to remain unchanged.
189+ */
190+ Optional<String> RequiresPackedAPI (const PrimFunc& func) {
191+ // A function with an explicit calling convention has already been
192+ // lowered, and should not be modified.
193+ if (auto opt = func->GetAttr <Integer>(tvm::attr::kCallingConv )) {
194+ if (CallingConv (opt.value ()->value ) != CallingConv::kDefault ) {
195+ return NullOpt;
196+ }
197+ }
198+
199+ // Internal function calls do not need the PackedFunc API
143200 auto global_symbol = func->GetAttr <String>(tvm::attr::kGlobalSymbol );
144- ICHECK (global_symbol) << " MakePackedAPI: Expect PrimFunc to have the global_symbol attribute" ;
201+ if (!global_symbol.defined ()) {
202+ return NullOpt;
203+ }
145204
146- auto target = func->GetAttr <Target>(tvm::attr::kTarget );
147- ICHECK (target.defined ()) << " MakePackedAPI: Require the target attribute" ;
148- int target_device_type = target.value ()->GetTargetDeviceType ();
205+ return global_symbol;
206+ }
149207
208+ PrimFunc MakePackedAPI (PrimFunc func) {
209+ auto global_symbol = RequiresPackedAPI (func);
210+ if (!global_symbol.defined ()) {
211+ return func;
212+ }
150213 std::string name_hint = global_symbol.value ();
151214
215+ Target target = [&]() {
216+ auto opt = func->GetAttr <Target>(tvm::attr::kTarget );
217+ ICHECK (opt) << " MakePackedAPI required the function to be annotated with tvm::attr::kTarget ("
218+ << tvm::attr::kTarget << " ), but the function only has attributes " << func->attrs ;
219+ return opt.value ();
220+ }();
221+ int target_device_type = target->GetTargetDeviceType ();
222+
152223 auto * func_ptr = func.CopyOnWrite ();
153224 const Stmt nop = Evaluate (0 );
154225 int num_args = static_cast <int >(func_ptr->params .size ());
@@ -292,39 +363,55 @@ PrimFunc MakePackedAPI(PrimFunc&& func) {
292363 func_ptr->params = args;
293364
294365 Array<Var> undefined = UndefinedVars (func_ptr->body , func_ptr->params );
295- ICHECK_EQ (undefined.size (), 0 ) << " In PrimFunc " << global_symbol << " variables " << undefined
366+ ICHECK_EQ (undefined.size (), 0 ) << " In PrimFunc " << name_hint << " variables " << undefined
296367 << " are used, but are not passed in as API arguments" ;
297368
298369 func_ptr->buffer_map = Map<Var, Buffer>();
299370 func_ptr->checked_type_ = func_ptr->func_type_annotation ();
300371 func_ptr->ret_type = PrimType (DataType::Int (32 ));
301372
302373 // return the function.
303- return std::move ( func) ;
374+ return func;
304375}
305376
306377namespace transform {
307378
308379Pass MakePackedAPI () {
309- auto pass_func = [](IRModule m, PassContext ctx) {
310- IRModuleNode* mptr = m.CopyOnWrite ();
311- std::vector<std::pair<GlobalVar, PrimFunc>> updates;
380+ auto pass_func = [](IRModule mod, PassContext ctx) {
381+ Map<GlobalVar, String> packed_func_methods;
382+ for (const auto & [gvar, base_func] : mod->functions ) {
383+ if (auto opt = base_func.as <PrimFunc>()) {
384+ auto prim_func = opt.value ();
385+ if (auto global_symbol = RequiresPackedAPI (prim_func)) {
386+ packed_func_methods.Set (gvar, global_symbol.value ());
387+ }
388+ }
389+ }
390+
391+ IRModuleNode* mptr = mod.CopyOnWrite ();
392+ IRModule updates;
312393
313- for (const auto & kv : mptr->functions ) {
314- if (auto opt = kv. second .as <PrimFunc>()) {
394+ for (const auto & [gvar, base_func] : mptr->functions ) {
395+ if (auto opt = base_func .as <PrimFunc>()) {
315396 auto func = opt.value ();
316- if (func->GetAttr <Integer>(tvm::attr::kCallingConv , Integer (CallingConv::kDefault )) ==
317- CallingConv::kDefault ) {
318- auto updated_func = MakePackedAPI (std::move (func));
319- updates.push_back ({kv.first , updated_func});
397+ auto orig_func = func;
398+
399+ if (auto body = SubroutineCallRewriter::Apply (packed_func_methods, func->body )) {
400+ func.CopyOnWrite ()->body = body.value ();
401+ }
402+
403+ func = MakePackedAPI (std::move (func));
404+
405+ if (!func.same_as (orig_func)) {
406+ updates->Add (gvar, func);
320407 }
321408 }
322409 }
323410
324- for ( const auto & pair : updates) {
325- mptr-> AddUnchecked (pair. first , pair. second );
411+ if ( updates-> functions . size () ) {
412+ mod. CopyOnWrite ()-> Update (updates );
326413 }
327- return m ;
414+ return mod ;
328415 };
329416
330417 return tvm::transform::CreateModulePass (pass_func, 0 , " tir.MakePackedAPI" , {});
0 commit comments