Commit f49b1f8

[TIR][Transform] Implement InlinePrivateFunctions
The functionality to express a call from one `PrimFunc` to another was introduced in apache#14889. While this was initially planned to be supported at codegen for all targets (see apache#15835), this resulted in breakage on some backends (see apache#16033). After discussion, the plan was changed to support TIR inlining, which would enable the same high-level functionality in TIR without requiring immediate low-level support across all codegens.

This commit implements and tests a new IRModule transform `InlinePrivateFunctions`, which can be used as part of lowering in a follow-up commit.

Because this is initially implemented for use quite late in the lowering flow, many constructs are not currently supported. The current implementation has the following restrictions.

* `tir::Block` nodes may not occur in the inlined function. Because a subroutine may be called multiple times, inlining of a subroutine that contains `tir::Block` would result in non-unique names. Support of subroutines with `tir::Block` instances will require de-duplication of block names.

* The subroutine's callsite must occur within a `tir::Evaluate` node. Because inlining a subroutine inserts the `tir::Stmt` body at the point of use, replacement must occur in a context where a `tir::Stmt` can be returned. Support of subroutines that are called within an expression (e.g. replacing `func` in `Buf[0] = func(1) + func(2)`) would require hoisting preprocessing done in the subroutine to the parent `tir::Stmt`.

* The subroutine may only accept primitive arguments, and must have an empty `buffer_map`. Support of subroutines that are called with `tir::Buffer` or `tir::BufferRegion` arguments would require a way to represent these arguments at the callsite, and substitution of the buffer into the callee.

If these unsupported constructs are used, then the inlining of those functions is skipped. This commit includes unit tests for these unsupported constructs, to validate that `InlinePrivateFunctions` produces well-formed output even when they are present.
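
The restrictions above can be illustrated with a small, self-contained sketch. It does not use TVM: the tuple-based IR, the `Func` class, and the helpers `can_inline`, `subst_expr`, `subst_stmt`, and `inline_stmt` are all hypothetical names invented for this illustration. The sketch assumes callees are non-recursive, and it only models the idea that a call in statement position is replaced by the callee body with parameters substituted, while a call nested in an expression is left alone.

```python
from dataclasses import dataclass

# Toy IR (hypothetical, not TVM classes).
# Expressions: int constant, str variable name,
#              ("call", name, [args]), ("add", lhs, rhs)
# Statements:  ("evaluate", expr), ("store", buffer_name, expr), ("seq", [stmts])


@dataclass
class Func:
    params: list
    body: tuple
    num_buffer_params: int = 0  # stands in for a non-empty buffer_map
    has_block: bool = False     # stands in for a body containing a block


def can_inline(func):
    # Callee-level restrictions: no buffer arguments, no block nodes.
    return func.num_buffer_params == 0 and not func.has_block


def subst_expr(expr, env):
    # Replace parameter names with the argument expressions bound in env.
    if isinstance(expr, str):
        return env.get(expr, expr)
    if isinstance(expr, tuple) and expr[0] == "call":
        return ("call", expr[1], [subst_expr(arg, env) for arg in expr[2]])
    if isinstance(expr, tuple) and expr[0] == "add":
        return ("add", subst_expr(expr[1], env), subst_expr(expr[2], env))
    return expr


def subst_stmt(stmt, env):
    kind = stmt[0]
    if kind == "evaluate":
        return ("evaluate", subst_expr(stmt[1], env))
    if kind == "store":
        return ("store", stmt[1], subst_expr(stmt[2], env))
    if kind == "seq":
        return ("seq", [subst_stmt(s, env) for s in stmt[1]])
    raise ValueError(f"unknown statement kind: {kind}")


def inline_stmt(stmt, funcs):
    kind = stmt[0]
    if kind == "seq":
        return ("seq", [inline_stmt(s, funcs) for s in stmt[1]])
    if kind == "evaluate":
        expr = stmt[1]
        if isinstance(expr, tuple) and expr[0] == "call":
            callee = funcs.get(expr[1])
            if callee is not None and can_inline(callee):
                if len(callee.params) != len(expr[2]):
                    raise ValueError("argument count mismatch")
                env = dict(zip(callee.params, expr[2]))
                # Revisit the inlined body in case it contains further calls.
                return inline_stmt(subst_stmt(callee.body, env), funcs)
    # Calls inside other expressions (e.g. a store value) are not inlined.
    return stmt


helper = Func(params=["x"], body=("evaluate", ("call", "extern_op", ["x"])))
main = ("seq", [
    ("evaluate", ("call", "helper", [5])),
    ("store", "Buf", ("add", ("call", "helper", [1]), ("call", "helper", [2]))),
])

result = inline_stmt(main, {"helper": helper})
blocked = inline_stmt(main, {"helper": Func(["x"], helper.body, has_block=True)})
```

In this sketch, the first statement of `result` becomes the callee body with `x` replaced by `5`, the store statement keeps both calls to `helper`, and `blocked` is identical to `main` because the callee fails the `can_inline` check.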
1 parent b484142 commit f49b1f8

4 files changed, +544 -0 lines changed

include/tvm/tir/transform.h

Lines changed: 7 additions & 0 deletions
@@ -414,6 +414,13 @@ TVM_DLL Pass BF16StorageLegalize();
  */
 TVM_DLL Pass FP8StorageLegalize();
 
+/*!
+ * \brief Inline calls to private functions
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass InlinePrivateFunctions();
+
 /*!
  * \brief Rewrite the pointer content type of arguments,
  *  as well as Alloc internal to the function to use

python/tvm/tir/transform/transform.py

Lines changed: 11 additions & 0 deletions
@@ -230,6 +230,17 @@ def StorageRewrite():
     return _ffi_api.StorageRewrite()  # type: ignore
 
 
+def InlinePrivateFunctions():
+    """Inline calls to private functions
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+    """
+    return _ffi_api.InlinePrivateFunctions()  # type: ignore
+
+
 def PointerValueTypeRewrite():
     """
     Rewrite the pointer content type of arguments, as well as Alloc internal to the function to use
Lines changed: 273 additions & 0 deletions
@@ -0,0 +1,273 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file inline_private_functions.cc
+ * \brief Inline private functions to their callsite
+ */
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/builtin.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/stmt.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+namespace tvm {
+namespace tir {
+namespace transform {
+
+namespace {
+
+template <typename T>
+using PSet = std::unordered_set<T, ObjectPtrHash, ObjectPtrEqual>;
+
+template <typename T, typename U>
+using PMap = std::unordered_map<T, U, ObjectPtrHash, ObjectPtrEqual>;
+
+PMap<GlobalVar, PSet<GlobalVar>> CollectCallMap(const IRModule& mod) {
+  struct Visitor : StmtExprVisitor {
+    GlobalVar current;
+    PMap<GlobalVar, PSet<GlobalVar>> caller_lookup;
+
+    void VisitExpr_(const CallNode* op) {
+      if (auto gvar = op->op.as<GlobalVar>()) {
+        caller_lookup[gvar.value()].insert(current);
+      }
+      StmtExprVisitor::VisitExpr_(op);
+    }
+  } visitor;
+
+  for (const auto& [gvar, base_func] : mod->functions) {
+    if (auto prim_func = base_func.as<PrimFuncNode>()) {
+      visitor.current = gvar;
+      visitor(prim_func->body);
+    }
+  }
+
+  return visitor.caller_lookup;
+}
+
+PSet<GlobalVar> CollectRecursiveFunctions(const IRModule& mod) {
+  // Collect all direct callers
+  auto call_map = CollectCallMap(mod);
+
+  // Propagate to find all indirect callers
+  while (true) {
+    bool made_change = false;
+    for (const auto& [callee, callers] : call_map) {
+      for (const auto& caller : callers) {
+        if (auto it = call_map.find(caller); it != call_map.end()) {
+          PSet<GlobalVar>& indirect_callers = it->second;
+
+          auto res = indirect_callers.insert(callee);
+          made_change = made_change || res.second;
+        }
+      }
+    }
+    if (!made_change) {
+      break;
+    }
+  }
+
+  // Filter all GlobalVars that can be called by themselves, either
+  // directly or indirectly.
+  PSet<GlobalVar> recursive_funcs;
+  for (const auto& [caller, callees] : call_map) {
+    if (callees.count(caller)) {
+      recursive_funcs.insert(caller);
+    }
+  }
+  return recursive_funcs;
+}
+
+Map<GlobalVar, PrimFunc> CollectInlinablePrimFuncs(const IRModule& mod) {
+  auto recursive_functions = CollectRecursiveFunctions(mod);
+
+  Map<GlobalVar, PrimFunc> output;
+  for (const auto& [gvar, base_func] : mod->functions) {
+    if (auto opt = base_func.as<PrimFunc>()) {
+      auto prim_func = opt.value();
+
+      // Only inline private functions. Externally-exposed functions
+      // must be preserved so as to avoid breaking callsites outside of
+      // the IRModule.
+      bool is_exposed = prim_func->GetAttr<String>(tvm::attr::kGlobalSymbol).defined();
+
+      // We do not currently implement any analysis for termination of
+      // a function. If a recursive function requires runtime checks
+      // in order to terminate, we would keep inlining until the
+      // recursive visits segfault.
+      bool is_recursive = recursive_functions.count(gvar);
+
+      // We do not currently support inlining of functions that accept
+      // buffer arguments.
+      bool has_buffer_arguments = prim_func->buffer_map.size();
+
+      // We do not currently support inlining of schedulable TIR
+      // functions. To support this use case, repeated names in
+      // `tir::Block` nodes resulting from multiple calls to the same
+      // inlined function will need to be de-duplicated.
+      bool has_block_node = prim_func->body.as<BlockRealizeNode>();
+
+      if (!is_exposed && !is_recursive && !has_buffer_arguments && !has_block_node) {
+        output.Set(gvar, prim_func);
+      }
+    }
+  }
+
+  return output;
+}
+
+class PrimFuncInliner : StmtExprMutator {
+ public:
+  explicit PrimFuncInliner(Map<GlobalVar, PrimFunc> inlinable_funcs)
+      : inlinable_funcs_(inlinable_funcs) {
+    for (const auto& [gvar, callee] : inlinable_funcs_) {
+      removable_funcs_.insert(gvar);
+    }
+  }
+
+  PrimFunc VisitFunc(PrimFunc func) {
+    current_target_ = func->GetAttr<Target>(tvm::attr::kTarget);
+    auto new_body = VisitStmt(func->body);
+    current_target_ = NullOpt;
+
+    if (!new_body.same_as(func->body)) {
+      func.CopyOnWrite()->body = new_body;
+    }
+
+    return func;
+  }
+
+  PSet<GlobalVar> GetRemovableFunctions() const { return removable_funcs_; }
+
+ private:
+  Stmt VisitStmt_(const EvaluateNode* eval) override {
+    if (auto call = eval->value.as<CallNode>()) {
+      if (auto gvar = call->op.as<GlobalVar>()) {
+        if (auto opt_callee = inlinable_funcs_.Get(gvar.value())) {
+          auto callee = opt_callee.value();
+
+          bool is_same_target = [&]() -> bool {
+            auto callee_target = callee->GetAttr<Target>(tvm::attr::kTarget);
+            if (current_target_ && callee_target) {
+              return callee_target.value()->str() == current_target_.value()->str();
+            } else {
+              return true;
+            }
+          }();
+
+          if (is_same_target) {
+            Stmt inlined = InlineArguments(gvar.value(), callee, call->args);
+            return VisitStmt(inlined);
+          }
+        }
+      }
+    }
+
+    return StmtExprMutator::VisitStmt_(eval);
+  }
+
+  PrimExpr VisitExpr_(const CallNode* call) override {
+    // Any callee that hasn't been inlined at this point must be kept
+    // in the output IRModule.
+    if (auto gvar = call->op.as<GlobalVar>()) {
+      removable_funcs_.erase(gvar.value());
+    }
+    return StmtExprMutator::VisitExpr_(call);
+  }
+
+  Stmt InlineArguments(const GlobalVar& gvar, PrimFunc callee, const Array<PrimExpr>& args) const {
+    CHECK_EQ(callee->params.size(), args.size())
+        << "Callee " << gvar << " accepts " << callee->params.size() << " parameters ("
+        << callee->params << "), but is called with " << args.size() << " arguments (" << args
+        << ")";
+
+    ICHECK(callee->buffer_map.empty())
+        << "Inlining of PrimFuncs with buffer arguments is not yet supported, "
+        << "but callee " << gvar << " has non-empty buffer map " << callee->buffer_map;
+
+    Map<Var, ObjectRef> param_map;
+    for (size_t i = 0; i < callee->params.size(); i++) {
+      param_map.Set(callee->params[i], args[i]);
+    }
+
+    callee = Specialize(callee, param_map);
+
+    return callee->body;
+  }
+
+  // Map from GlobalVar to PrimFuncs which may be inlined.
+  Map<GlobalVar, PrimFunc> inlinable_funcs_;
+
+  /* \brief Set of callees that may be removed
+   *
+   * Some constructs may not be inlined (e.g. if the call site occurs
+   * outside of an Evaluate node). For these cases, the output
+   * IRModule must still contain the callee.
+   */
+  PSet<GlobalVar> removable_funcs_;
+
+  Optional<Target> current_target_ = NullOpt;
+};
+
+}  // namespace
+
+Pass InlinePrivateFunctions() {
+  auto pass_func = [](IRModule mod, PassContext ctx) {
+    auto inlinable_prim_funcs = CollectInlinablePrimFuncs(mod);
+
+    if (inlinable_prim_funcs.empty()) {
+      // Early bail-out if the module has no inlinable PrimFuncs.
+      return mod;
+    }
+
+    PrimFuncInliner mutator(std::move(inlinable_prim_funcs));
+    IRModule updates;
+
+    for (const auto& [gvar, base_func] : mod->functions) {
+      if (auto opt = base_func.as<PrimFunc>()) {
+        auto updated = mutator.VisitFunc(opt.value());
+        if (!updated.same_as(base_func)) {
+          updates->Add(gvar, updated);
+        }
+      }
+    }
+
+    if (updates->functions.size()) {
+      auto write_ptr = mod.CopyOnWrite();
+      write_ptr->Update(updates);
+      for (const auto& gvar : mutator.GetRemovableFunctions()) {
+        write_ptr->Remove(gvar);
+      }
+      mod = ConvertSSA()(mod);
+    }
+
+    return mod;
+  };
+  return tvm::transform::CreateModulePass(pass_func, 0, "tir.InlinePrivateFunctions", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.InlinePrivateFunctions").set_body_typed(InlinePrivateFunctions);
+
+}  // namespace transform
+
+}  // namespace tir
+}  // namespace tvm
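
The comments in `CollectRecursiveFunctions` and `CollectInlinablePrimFuncs` describe excluding any function that can call itself, directly or indirectly, and any function that is externally exposed. The following is a minimal standalone sketch of that idea in plain Python, not a port of the C++ above. The names `collect_recursive` and `collect_inlinable`, and the representation of the call graph as a dictionary from caller name to the set of direct callee names, are assumptions made for this illustration. It computes a transitive closure by repeating a propagation step until nothing changes.

```python
def collect_recursive(calls):
    """Return the functions that can reach themselves through calls.

    `calls[f]` is the set of functions that `f` calls directly
    (hypothetical representation, not a TVM data structure).
    """
    # Start from direct callees, then grow each set to a fixed point.
    reachable = {f: set(callees) for f, callees in calls.items()}
    changed = True
    while changed:
        changed = False
        for seen in reachable.values():
            extra = set()
            for g in seen:
                extra |= reachable.get(g, set())
            if not extra <= seen:
                seen |= extra
                changed = True
    return {f for f, seen in reachable.items() if f in seen}


def collect_inlinable(calls, exposed):
    """Keep only private, non-recursive functions as inlining candidates."""
    recursive = collect_recursive(calls)
    return {f for f in calls if f not in exposed and f not in recursive}


call_graph = {
    "main": {"a", "leaf"},
    "a": {"b"},
    "b": {"a"},    # a -> b -> a is indirect recursion
    "c": {"c"},    # direct recursion
    "leaf": set(),
}

recursive = collect_recursive(call_graph)
inlinable = collect_inlinable(call_graph, exposed={"main"})
```

With this input, `recursive` contains `a`, `b`, and `c`, and `inlinable` contains only `leaf`: `main` is excluded because it is exposed, and the rest are excluded because inlining them could never terminate.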
