Skip to content

Commit 41a616f

Browse files
authored
[TIR] Handle subroutine calls in MakePackedAPI (#14913)
* [TIR] MakePackedAPI, handle missing kGlobalSymbol

  Previously, `MakePackedAPI` required all functions to have the `kGlobalSymbol` attribute. This commit updates the behavior such that `MakePackedAPI` only modifies PrimFuncs that have the `kGlobalSymbol` attribute, and passes through any other PrimFunc unmodified.

* [TIR] Update calls to externally-exposed subroutines in MakePackedAPI

  When a function is updated to use the `PackedFunc` API, any calls made to that function from elsewhere in the `IRModule` should be updated as well.

* Bugfix, don't update the callsite unless the callee is also updated
1 parent bcf7abb commit 41a616f

File tree

2 files changed

+215
-21
lines changed

2 files changed

+215
-21
lines changed

src/tir/transforms/make_packed_api.cc

Lines changed: 106 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -135,20 +135,91 @@ Stmt RewriteReturn(Stmt body, Var ret_var, Var ret_tcode) {
135135
return rewriter(body);
136136
}
137137

138+
class SubroutineCallRewriter : public StmtExprMutator {
139+
public:
140+
static Optional<Stmt> Apply(const Map<GlobalVar, String>& packed_func_methods, Stmt stmt) {
141+
SubroutineCallRewriter rewriter(packed_func_methods);
142+
stmt = rewriter.VisitStmt(std::move(stmt));
143+
if (rewriter.made_change_) {
144+
return stmt;
145+
} else {
146+
return NullOpt;
147+
}
148+
}
149+
150+
private:
151+
explicit SubroutineCallRewriter(const Map<GlobalVar, String>& packed_func_methods)
152+
: packed_func_methods(packed_func_methods) {}
153+
154+
PrimExpr VisitExpr_(const CallNode* op) override {
155+
auto node = Downcast<Call>(StmtExprMutator::VisitExpr_(op));
156+
157+
if (auto* gvar_ptr = node->op.as<GlobalVarNode>()) {
158+
auto gvar = GetRef<GlobalVar>(gvar_ptr);
159+
if (auto symbol = packed_func_methods.Get(gvar)) {
160+
Array<PrimExpr> cpacked_args;
161+
cpacked_args.push_back(tir::StringImm(symbol.value()));
162+
for (auto arg : node->args) {
163+
cpacked_args.push_back(arg);
164+
}
165+
166+
// push an empty handle to be compatible with current cpacked convention
167+
cpacked_args.push_back(tir::make_zero(DataType::Handle()));
168+
made_change_ = true;
169+
return tir::Call(node->dtype, tir::builtin::tvm_call_cpacked(), cpacked_args);
170+
}
171+
}
172+
173+
return node;
174+
}
175+
const Map<GlobalVar, String>& packed_func_methods;
176+
bool made_change_{false};
177+
};
178+
138179
inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) {
139180
return AssertStmt(lhs == rhs, tvm::tir::StringImm(msg), Evaluate(0));
140181
}
141182

142-
PrimFunc MakePackedAPI(PrimFunc&& func) {
183+
/*! \brief Return the global_symbol of the function, if it should be updated
184+
*
185+
* \param func The function to be inspected
186+
*
187+
* \returns The global_symbol to be used for the function at call
188+
* sites, or NullOpt if the function is to remain unchanged.
189+
*/
190+
Optional<String> RequiresPackedAPI(const PrimFunc& func) {
191+
// A function with an explicit calling convention has already been
192+
// lowered, and should not be modified.
193+
if (auto opt = func->GetAttr<Integer>(tvm::attr::kCallingConv)) {
194+
if (CallingConv(opt.value()->value) != CallingConv::kDefault) {
195+
return NullOpt;
196+
}
197+
}
198+
199+
// Internal function calls do not need the PackedFunc API
143200
auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
144-
ICHECK(global_symbol) << "MakePackedAPI: Expect PrimFunc to have the global_symbol attribute";
201+
if (!global_symbol.defined()) {
202+
return NullOpt;
203+
}
145204

146-
auto target = func->GetAttr<Target>(tvm::attr::kTarget);
147-
ICHECK(target.defined()) << "MakePackedAPI: Require the target attribute";
148-
int target_device_type = target.value()->GetTargetDeviceType();
205+
return global_symbol;
206+
}
149207

208+
PrimFunc MakePackedAPI(PrimFunc func) {
209+
auto global_symbol = RequiresPackedAPI(func);
210+
if (!global_symbol.defined()) {
211+
return func;
212+
}
150213
std::string name_hint = global_symbol.value();
151214

215+
Target target = [&]() {
216+
auto opt = func->GetAttr<Target>(tvm::attr::kTarget);
217+
ICHECK(opt) << "MakePackedAPI required the function to be annotated with tvm::attr::kTarget ("
218+
<< tvm::attr::kTarget << "), but the function only has attributes " << func->attrs;
219+
return opt.value();
220+
}();
221+
int target_device_type = target->GetTargetDeviceType();
222+
152223
auto* func_ptr = func.CopyOnWrite();
153224
const Stmt nop = Evaluate(0);
154225
int num_args = static_cast<int>(func_ptr->params.size());
@@ -292,39 +363,55 @@ PrimFunc MakePackedAPI(PrimFunc&& func) {
292363
func_ptr->params = args;
293364

294365
Array<Var> undefined = UndefinedVars(func_ptr->body, func_ptr->params);
295-
ICHECK_EQ(undefined.size(), 0) << "In PrimFunc " << global_symbol << " variables " << undefined
366+
ICHECK_EQ(undefined.size(), 0) << "In PrimFunc " << name_hint << " variables " << undefined
296367
<< " are used, but are not passed in as API arguments";
297368

298369
func_ptr->buffer_map = Map<Var, Buffer>();
299370
func_ptr->checked_type_ = func_ptr->func_type_annotation();
300371
func_ptr->ret_type = PrimType(DataType::Int(32));
301372

302373
// return the function.
303-
return std::move(func);
374+
return func;
304375
}
305376

306377
namespace transform {
307378

308379
Pass MakePackedAPI() {
309-
auto pass_func = [](IRModule m, PassContext ctx) {
310-
IRModuleNode* mptr = m.CopyOnWrite();
311-
std::vector<std::pair<GlobalVar, PrimFunc>> updates;
380+
auto pass_func = [](IRModule mod, PassContext ctx) {
381+
Map<GlobalVar, String> packed_func_methods;
382+
for (const auto& [gvar, base_func] : mod->functions) {
383+
if (auto opt = base_func.as<PrimFunc>()) {
384+
auto prim_func = opt.value();
385+
if (auto global_symbol = RequiresPackedAPI(prim_func)) {
386+
packed_func_methods.Set(gvar, global_symbol.value());
387+
}
388+
}
389+
}
390+
391+
IRModuleNode* mptr = mod.CopyOnWrite();
392+
IRModule updates;
312393

313-
for (const auto& kv : mptr->functions) {
314-
if (auto opt = kv.second.as<PrimFunc>()) {
394+
for (const auto& [gvar, base_func] : mptr->functions) {
395+
if (auto opt = base_func.as<PrimFunc>()) {
315396
auto func = opt.value();
316-
if (func->GetAttr<Integer>(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) ==
317-
CallingConv::kDefault) {
318-
auto updated_func = MakePackedAPI(std::move(func));
319-
updates.push_back({kv.first, updated_func});
397+
auto orig_func = func;
398+
399+
if (auto body = SubroutineCallRewriter::Apply(packed_func_methods, func->body)) {
400+
func.CopyOnWrite()->body = body.value();
401+
}
402+
403+
func = MakePackedAPI(std::move(func));
404+
405+
if (!func.same_as(orig_func)) {
406+
updates->Add(gvar, func);
320407
}
321408
}
322409
}
323410

324-
for (const auto& pair : updates) {
325-
mptr->AddUnchecked(pair.first, pair.second);
411+
if (updates->functions.size()) {
412+
mod.CopyOnWrite()->Update(updates);
326413
}
327-
return m;
414+
return mod;
328415
};
329416

330417
return tvm::transform::CreateModulePass(pass_func, 0, "tir.MakePackedAPI", {});

tests/python/unittest/test_tir_transform_make_packed_api.py

Lines changed: 109 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,12 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18+
import pytest
19+
1820
import tvm
19-
from tvm import te
21+
import tvm.testing
22+
from tvm import te, tir
23+
from tvm.script import tir as T, ir as I
2024
from tvm.driver.build_module import schedule_to_module
2125

2226

@@ -39,7 +43,9 @@ def test_makeapi():
3943
)
4044
)(mod)
4145

42-
f = tvm.tir.transform.MakePackedAPI()(mod)["main"]
46+
before = mod
47+
after = tvm.tir.transform.MakePackedAPI()(mod)
48+
f = after["main"]
4349
assert len(f.params) == 6
4450

4551

@@ -59,6 +65,19 @@ def _find_next(stmt, type):
5965
return stmt
6066

6167

68+
def _find_compute_scope(func):
69+
result = None
70+
71+
def _visitor(stmt):
72+
if isinstance(stmt, tir.AttrStmt) and stmt.attr_key == "compute_scope":
73+
nonlocal result
74+
result = stmt
75+
76+
tir.stmt_functor.post_order_visit(func.body, _visitor)
77+
78+
return result
79+
80+
6281
def test_variable_passed_from_args():
6382
ib = tvm.tir.ir_builder.create()
6483

@@ -143,5 +162,93 @@ def test_device_api_context_implicit_resource_handle():
143162
assert call_extern.args[2] == device_context_in_resource_handle
144163

145164

165+
@pytest.mark.parametrize("use_global_symbol", [True, False])
166+
def test_no_op_when_global_symbol_is_absent(use_global_symbol):
167+
func_attr = {"target": tvm.target.Target("llvm")}
168+
if use_global_symbol:
169+
func_attr["global_symbol"] = "main"
170+
171+
@T.prim_func
172+
def before():
173+
T.func_attr(func_attr)
174+
T.evaluate(0)
175+
176+
after = tvm.tir.transform.MakePackedAPI()(tvm.IRModule.from_expr(before))["main"]
177+
if use_global_symbol:
178+
assert len(after.params) == 6
179+
else:
180+
tvm.ir.assert_structural_equal(before, after)
181+
182+
183+
def test_internal_subroutine_call():
184+
"""Internal subroutines should not use the PackedFunc API
185+
186+
A subroutine without the "global_symbol" attribute is an internal
187+
subroutine, and is not directly exposed to a user of the generated
188+
`runtime.Module`. Therefore, it doesn't need to follow the
189+
PackedFunc API.
190+
"""
191+
192+
@I.ir_module
193+
class before:
194+
@T.prim_func
195+
def main(A: T.Buffer(1, "float32")):
196+
T.func_attr({"global_symbol": "main", "target": T.target("llvm")})
197+
before.subroutine(A.data)
198+
199+
@T.prim_func
200+
def subroutine(A_data: T.handle("float32")):
201+
T.func_attr({"target": T.target("llvm")})
202+
T.evaluate(A_data)
203+
204+
after = tvm.tir.transform.MakePackedAPI()(before)
205+
tvm.ir.assert_structural_equal(before["subroutine"], after["subroutine"])
206+
207+
compute_scope = _find_compute_scope(after["main"])
208+
subroutine_call_op = compute_scope.body.value.op
209+
assert isinstance(subroutine_call_op, tvm.ir.GlobalVar), (
210+
f"The main function's CallNode should use the subroutine's GLobalVar as the operation, "
211+
f"but instead has an operation of type {subroutine_call_op}"
212+
)
213+
214+
215+
def test_subroutine_call_to_externally_visible_subroutine():
216+
"""Externally-visible subroutines should use the PackedFunc API
217+
218+
Because the subroutine may be called directly by a user, it must
219+
use the PackedFunc API. Its signature should be updated to the
220+
PackedFunc signature, and call sites should be updated to use
221+
`T.tvm_call_cpacked`.
222+
"""
223+
224+
@I.ir_module
225+
class before:
226+
@T.prim_func
227+
def main(A: T.Buffer(1, "float32")):
228+
T.func_attr({"global_symbol": "main", "target": T.target("llvm")})
229+
before.subroutine(A.data)
230+
231+
@T.prim_func
232+
def subroutine(A_data: T.handle("float32")):
233+
T.func_attr({"global_symbol": "subroutine", "target": T.target("llvm")})
234+
T.evaluate(A_data)
235+
236+
after = tvm.tir.transform.MakePackedAPI()(before)
237+
238+
main_compute_scope = _find_compute_scope(after["main"])
239+
assert main_compute_scope is not None
240+
subroutine_compute_scope = _find_compute_scope(after["subroutine"])
241+
assert subroutine_compute_scope is not None
242+
243+
subroutine_call_op = main_compute_scope.body.value.op
244+
assert (
245+
isinstance(subroutine_call_op, tvm.ir.Op)
246+
and subroutine_call_op.name == "tir.tvm_call_cpacked"
247+
), (
248+
f"The main function's CallNode should be lowered to the builtin 'tir.tvm_call_cpacked', "
249+
f"but instead has an operation of type {subroutine_call_op}"
250+
)
251+
252+
146253
if __name__ == "__main__":
147254
test_makeapi()

0 commit comments

Comments
 (0)