Skip to content

Commit 875217c

Browse files
authored
[TIR] Restrict tir.transform.InstallDebugSpans to host functions (#14943)
* [Bugfix] Avoid symbol conflicts in MakePackedAPI/MakeUnpackedAPI PRs #14913 and #14914 made analogous changes to `MakePackedAPI` and `MakeUnpackedAPI` to handle subroutine calls. Both PRs introduced the same symbol, `tvm::tir::SubroutineCallRewriter`, a local utility to update internal calls to a modified function. While each PR passed CI individually, and was therefore able to merge, having both changes caused a duplicate symbol. This commit updates `MakePackedAPI` and `MakeUnpackedAPI` to place their local utilities into anonymous namespaces, avoiding the conflict. * [Target] Added utility method TargetNode::HasKey() This utility method makes it easier to determine if a target contains a specific key. * [TIR] Added utility method tvm::tir::IsHostFunc(const PrimFunc&) For modules that contain both host and device functions, this utility function checks whether a given PrimFunc is a host function, based on the target annotation. * [TIR] Restrict InstallDebugSpans to host functions Previously, the `tir.InstallDebugSpans` pass required the module to contain only a single PrimFunc. This commit relaxes the requirement, to require a single host-side PrimFunc, and to ignore any other device-side functions.
1 parent 81056cc commit 875217c

File tree

2 files changed

+73
-13
lines changed

2 files changed

+73
-13
lines changed

src/tir/transforms/install_debug_spans.cc

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
#include <utility>
3232

3333
#include "../../relay/printer/tir_text_printer_debug.h"
34+
#include "ir_utils.h"
3435

3536
namespace tvm {
3637
namespace tir {
@@ -128,19 +129,30 @@ TVM_TIR_TRANSFORMS_INSTALL_DEBUG_SPANS_SUPPORTED_STMTS
128129
namespace transform {
129130

130131
Pass InstallDebugSpans() {
131-
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
132-
ICHECK(m->functions.size() == 1)
133-
<< "Debug info can only be added to IRModules with a single function";
134-
// There is known to be only 1 function in the module at this point
135-
auto entry = m->functions.begin();
136-
auto name = std::get<0>(*entry)->name_hint;
137-
auto* n = f.CopyOnWrite();
138-
139-
n->body = DebugInfoInstaller::InstallInfo(std::move(name), std::move(f->body));
140-
141-
return f;
132+
auto pass_func = [](IRModule mod, PassContext ctx) {
133+
Map<GlobalVar, PrimFunc> external_host_functions;
134+
for (const auto& [gvar, base_func] : mod->functions) {
135+
if (auto opt = base_func.as<PrimFunc>()) {
136+
auto prim_func = opt.value();
137+
if (IsHostFunc(prim_func).value_or(false) &&
138+
prim_func->GetAttr<String>(tvm::attr::kGlobalSymbol)) {
139+
external_host_functions.Set(gvar, prim_func);
140+
}
141+
}
142+
}
143+
144+
ICHECK_EQ(external_host_functions.size(), 1)
145+
<< "Debug info can only be added to IRModules with a single host function";
146+
147+
for (auto [gvar, prim_func] : external_host_functions) {
148+
auto name = prim_func->GetAttr<String>(tvm::attr::kGlobalSymbol).value();
149+
prim_func.CopyOnWrite()->body = DebugInfoInstaller::InstallInfo(name, prim_func->body);
150+
mod.CopyOnWrite()->Update(gvar, prim_func);
151+
}
152+
153+
return mod;
142154
};
143-
return CreatePrimFuncPass(pass_func, 0, "tir.InstallDebugSpans", {});
155+
return tvm::transform::CreateModulePass(pass_func, 0, "tir.InstallDebugSpans", {});
144156
}
145157

146158
TVM_REGISTER_GLOBAL("tir.transform.InstallDebugSpans").set_body_typed(InstallDebugSpans);

tests/python/tir/test_debug_info.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,13 @@ class MyModule:
4646
@T.prim_func
4747
def main(a: T.handle, b: T.handle):
4848
# We exchange data between function by handles, which are similar to pointer.
49-
T.func_attr({"global_symbol": "main", "tir.noalias": True})
49+
T.func_attr(
50+
{
51+
"global_symbol": "main",
52+
"tir.noalias": True,
53+
"target": T.target("llvm"),
54+
}
55+
)
5056
# Create buffer from handles.
5157
A = T.match_buffer(a, (8,), dtype="float32")
5258
B = T.match_buffer(b, (8,), dtype="float32")
@@ -83,6 +89,48 @@ def find_span(m):
8389
assert span_after.line == 4
8490

8591

92+
def test_tir_debug_info_with_subroutine():
93+
"""Like test_tir_debug_info, but with a TIR subroutine
94+
95+
The current InstallDebugSpans applies to a single PrimFunc. This
96+
test verifies that the existence of device-side subroutines
97+
98+
"""
99+
100+
def find_span(m):
101+
func = next(m.functions.values())
102+
return func.body.block.body.span
103+
104+
@tvm.script.ir_module
105+
class module_before:
106+
@T.prim_func
107+
def main(a: T.handle, b: T.handle):
108+
T.func_attr({"global_symbol": "main", "tir.noalias": True, "target": T.target("llvm")})
109+
A = T.match_buffer(a, (8,), dtype="float32")
110+
B = T.match_buffer(b, (8,), dtype="float32")
111+
for i in range(8):
112+
with T.block("B"):
113+
vi = T.axis.spatial(8, i)
114+
module_before.subroutine(T.address_of(A[vi]), T.address_of(B[vi]))
115+
116+
@T.prim_func
117+
def subroutine(a_ptr: T.handle("float32"), b_ptr: T.handle("float32")):
118+
T.func_attr({"global_symbol": "main", "tir.noalias": True})
119+
A = T.decl_buffer(1, "float32", data=a_ptr)
120+
B = T.decl_buffer(1, "float32", data=b_ptr)
121+
B[0] = A[1] + 1.0
122+
123+
span_before = find_span(module_before)
124+
assert span_before is None
125+
126+
module_after = tir.transform.InstallDebugSpans()(module_before)
127+
span_after = find_span(module_after)
128+
129+
# Check that the module name has been added and a line number is present
130+
assert span_after.source_name.name == "main.tir"
131+
assert span_after.line == 4
132+
133+
86134
def test_llvm_ir_debug_info():
87135
"""
88136
Check that the right amount of debug locations are present

0 commit comments

Comments
 (0)