Skip to content

Commit 762e229

Browse files
committed
[Unity][CodeGen] RunCodegen based on externally-exposed functions
Prior to this commit, `relax.transform.RunCodegen` required a list of entry functions for a module, defaulting to `"main"` if not specified. The list of entry functions is duplicate information that could be inferred from the module, and should not be required from the user. This commit updates `RunCodegen` to treat all externally-exposed functions as entry points, in the same manner as `DeadCodeElimination`. For backwards compatibility, the `entry_functions` argument is still accepted, and is used to augment the list of externally-exposed functions.
1 parent 1f779f7 commit 762e229

File tree

3 files changed

+72
-14
lines changed

3 files changed

+72
-14
lines changed

python/tvm/relax/transform/transform.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -574,7 +574,8 @@ def RunCodegen(
574574
The registered pass to remove unused functions.
575575
"""
576576
if entry_functions is None:
577-
entry_functions = ["main"]
577+
entry_functions = []
578+
578579
# enable cutlass byoc registries
579580
# pylint: disable=unused-import,import-outside-toplevel
580581
from tvm.contrib import cutlass as _cutlass

src/relax/transform/run_codegen.cc

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929
#include <iostream>
3030

31+
#include "../../support/ordered_set.h"
3132
#include "utils.h"
3233

3334
namespace tvm {
@@ -39,12 +40,39 @@ class CodeGenRunner : ExprMutator {
3940

4041
explicit CodeGenRunner(IRModule mod) : ExprMutator(mod) {}
4142

42-
IRModule Run(Optional<Map<String, OptionMap>> target_options, Array<String> entry_functions) {
43+
IRModule Run(Optional<Map<String, OptionMap>> target_options,
44+
Array<String> entry_function_names) {
4345
IRModule mod = builder_->GetContextIRModule();
44-
for (const String& entry_func_name : entry_functions) {
45-
auto entry_func = mod->Lookup(entry_func_name);
46-
auto gvar = mod->GetGlobalVar(entry_func_name);
47-
builder_->UpdateFunction(gvar, Downcast<BaseFunc>(VisitExpr(entry_func)));
46+
47+
support::OrderedSet<GlobalVar> entry_functions;
48+
// Any user-provided functions are treated as entry functions.
49+
for (const auto& name : entry_function_names) {
50+
entry_functions.insert(mod->GetGlobalVar(name));
51+
}
52+
53+
// In addtion, any externally-exposed function that does not
54+
// belong to a specific codegen may be an entry function. These
55+
// are added in alphabetical order, to ensure consistent order of
56+
// evaluation for debug/test purposes.
57+
{
58+
std::vector<GlobalVar> attr_entry_functions;
59+
for (const auto& [gv, func] : mod->functions) {
60+
if (func->GetLinkageType() == LinkageType::kExternal &&
61+
!func->GetAttr<String>(attr::kCodegen) && func->IsInstance<relax::FunctionNode>()) {
62+
attr_entry_functions.push_back(gv);
63+
}
64+
}
65+
std::sort(attr_entry_functions.begin(), attr_entry_functions.end(),
66+
[](const auto& gvar_a, const auto& gvar_b) {
67+
return gvar_a->name_hint > gvar_b->name_hint;
68+
});
69+
for (const auto& gvar : attr_entry_functions) {
70+
entry_functions.insert(gvar);
71+
}
72+
}
73+
74+
for (const auto& gvar : entry_functions) {
75+
builder_->UpdateFunction(gvar, Downcast<BaseFunc>(VisitExpr(mod->Lookup(gvar))));
4876
}
4977

5078
auto ext_mods = InvokeCodegen(mod, target_options.value_or({}));
@@ -65,7 +93,7 @@ class CodeGenRunner : ExprMutator {
6593
}
6694

6795
// TODO(@tvm-team): Implicit pass dependency. Revisit when we have a better way to handle this.
68-
return DeadCodeElimination(out_mod, entry_functions);
96+
return DeadCodeElimination(out_mod, entry_function_names);
6997
}
7098

7199
using ExprMutator::VisitExpr_;

tests/python/relax/test_transform_codegen_pass.py

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -48,21 +48,21 @@
4848
dev = tvm.cuda()
4949

5050

51-
def check_executable(exec, dev, inputs, expected):
51+
def check_executable(exec, dev, inputs, expected, entry_func_name):
5252
vm = relax.VirtualMachine(exec, dev)
53-
out = vm["main"](*inputs)
53+
out = vm[entry_func_name](*inputs)
5454
tvm.testing.assert_allclose(out.numpy(), expected.numpy(), atol=1e-5, rtol=1e-5)
5555

5656

57-
def check_roundtrip(exec0, dev, inputs, expected):
57+
def check_roundtrip(exec0, dev, inputs, expected, entry_func_name="main"):
5858
exec0.mod.export_library("exec.so")
5959
exec1 = tvm.runtime.load_module("exec.so")
6060
os.remove("exec.so")
6161
assert exec0.stats() == exec1["stats"]()
6262
assert exec0.as_text() == exec1["as_text"]()
6363

64-
check_executable(exec0, dev, inputs, expected)
65-
check_executable(exec1, dev, inputs, expected)
64+
check_executable(exec0, dev, inputs, expected, entry_func_name)
65+
check_executable(exec1, dev, inputs, expected, entry_func_name)
6666

6767

6868
def gen_ground_truth(mod, target, dev, inputs):
@@ -113,10 +113,17 @@ def setup_test():
113113
return mod, inputs, expected
114114

115115

116+
entry_func_name = tvm.testing.parameter("main", "func")
117+
118+
116119
@tvm.testing.requires_gpu
117-
def test_tensorrt_only():
120+
def test_tensorrt_only(entry_func_name):
118121
mod, inputs, expected = setup_test()
119122

123+
if entry_func_name != "main":
124+
mod[entry_func_name] = mod
125+
del mod["main"]
126+
120127
# Define patterns that we want to offload to byoc
121128
# This test will offload entire model
122129
# Thus, define patterns for both `multiply` and `add` ops
@@ -135,7 +142,7 @@ def test_tensorrt_only():
135142

136143
ex0 = relax.build(new_mod, target, params={})
137144
# Sanity check for the correctness and roundtrip
138-
check_roundtrip(ex0, dev, inputs, expected)
145+
check_roundtrip(ex0, dev, inputs, expected, entry_func_name)
139146

140147

141148
@tvm.testing.requires_gpu
@@ -248,6 +255,28 @@ def test_multiple_calls_same_extern():
248255
tvm.ir.assert_structural_equal(mod["main"], Conv2dx2_after["main"])
249256

250257

258+
def test_default_entry_func():
259+
"""The entry function is not necessarily named "main"
260+
261+
Like `test_multiple_calls_same_extern`, but the main function is
262+
named "func".
263+
"""
264+
before_with_main = Conv2dx2
265+
after_with_main = relax.transform.RunCodegen()(before_with_main)
266+
267+
def rename_main(mod):
268+
mod = mod.clone()
269+
mod["func"] = mod["main"].with_attr("global_symbol", "func")
270+
del mod["main"]
271+
return mod
272+
273+
before_with_func = rename_main(before_with_main)
274+
expected_with_func = rename_main(after_with_main)
275+
after_with_func = relax.transform.RunCodegen()(before_with_func)
276+
277+
tvm.ir.assert_structural_equal(expected_with_func["func"], after_with_func["func"])
278+
279+
251280
def test_dynamic_shape():
252281
import tvm.relax.backend.contrib.cublas
253282

0 commit comments

Comments
 (0)