Skip to content

Commit 7dca655

Browse files
wweiczhiics
authored andcommitted
[Relay][Pass] Add pass to remove unused functions in relay module (#4334)
* [Relay][Pass] Add pass to remove unused functions in relay module * Add tests * Fix lint * Fix visit order * Add pass argument * Fix
1 parent 5b9f459 commit 7dca655

File tree

4 files changed

+228
-0
lines changed

4 files changed

+228
-0
lines changed

python/tvm/relay/transform.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,22 @@ def BackwardFoldScaleAxis():
297297
"""
298298
return _transform.BackwardFoldScaleAxis()
299299

300+
def RemoveUnusedFunctions(entry_functions=None):
301+
"""Remove unused global relay functions in a relay module.
302+
303+
Parameters
304+
----------
305+
entry_functions: list[string]
306+
The set of entry functions to start from.
307+
308+
Returns
309+
-------
310+
ret : tvm.relay.Pass
311+
The registered pass to remove unused functions.
312+
"""
313+
if entry_functions is None:
314+
entry_functions = ['main']
315+
return _transform.RemoveUnusedFunctions(entry_functions)
300316

301317
def ForwardFoldScaleAxis():
302318
"""Fold the scaling of axis into weights of conv2d/dense.

src/relay/backend/vm/compiler.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ namespace transform {
5454

5555
Pass LambdaLift();
5656
Pass InlinePrimitives();
57+
Pass RemoveUnusedFunctions(Array<tvm::Expr> entry_functions);
5758

5859
Pass ManifestAlloc(Target target_host) {
5960
auto f = tvm::runtime::Registry::Get("relay.transform.ManifestAlloc");
@@ -863,6 +864,8 @@ void VMCompiler::Compile(Module mod,
863864

864865
Module VMCompiler::OptimizeModule(const Module& mod, const TargetsMap& targets) {
865866
Array<Pass> pass_seqs;
867+
Array<tvm::Expr> entry_functions{tvm::Expr{"main"}};
868+
pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions));
866869
// Run all dialect legalization passes.
867870
pass_seqs.push_back(relay::qnn::transform::Legalize());
868871

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* Copyright (c) 2019 by Contributors
22+
* \file tvm/relay/backend/vm/remove_unused_funcs.cc
23+
* \brief Remove unused global relay functions in a relay module.
24+
*/
25+
26+
#include <tvm/relay/expr.h>
27+
#include <tvm/relay/expr_functor.h>
28+
#include <tvm/logging.h>
29+
#include <tvm/relay/analysis.h>
30+
#include <tvm/relay/transform.h>
31+
#include <tvm/runtime/vm.h>
32+
#include <iostream>
33+
#include <unordered_set>
34+
#include <vector>
35+
36+
namespace tvm {
37+
namespace relay {
38+
namespace vm {
39+
40+
/**
41+
* \brief Detects all the functions that can be possibly called by entry function.
42+
*/
43+
struct CallTracer : ExprVisitor {
44+
Module module_;
45+
46+
// Record the names of all encountered functions
47+
std::unordered_set<std::string> called_funcs_;
48+
49+
// Record the expressions that are being visited
50+
std::unordered_set<Expr, NodeHash, NodeEqual> visiting_;
51+
52+
explicit CallTracer(const Module& module)
53+
: module_{module},
54+
called_funcs_{},
55+
visiting_{} {}
56+
57+
void VisitExpr_(const CallNode* call_node) final {
58+
Expr op = call_node->op;
59+
for (auto param : call_node->args) {
60+
VisitExpr(param);
61+
}
62+
if (auto func_node = op.as<FunctionNode>()) {
63+
auto func = GetRef<Function>(func_node);
64+
auto it = visiting_.find(func);
65+
if (it != visiting_.end()) {
66+
return;
67+
}
68+
visiting_.insert(func);
69+
VisitExpr(func);
70+
} else if (auto global = op.as<GlobalVarNode>()) {
71+
called_funcs_.insert(global->name_hint);
72+
auto func = module_->Lookup(global->name_hint);
73+
auto it = visiting_.find(func);
74+
if (it != visiting_.end()) {
75+
return;
76+
}
77+
visiting_.insert(func);
78+
VisitExpr(func);
79+
}
80+
}
81+
82+
std::unordered_set<std::string> Trace(const std::string& entry) {
83+
called_funcs_.insert(entry);
84+
auto main_func = module_->Lookup(entry);
85+
VisitExpr(main_func);
86+
return called_funcs_;
87+
}
88+
};
89+
90+
/*!
91+
* \brief Remove functions that are not used.
92+
*
93+
* \param module The Relay module.
94+
* \param entry_funcs The set of functions that can be entry function.
95+
*
96+
* \return The module with dead functions removed.
97+
*/
98+
Module RemoveUnusedFunctions(const Module& module,
99+
Array<tvm::Expr> entry_funcs) {
100+
std::unordered_set<std::string> called_funcs{};
101+
for (auto entry : entry_funcs) {
102+
auto* str_name = entry.as<ir::StringImm>();
103+
auto funcs = CallTracer(module).Trace(str_name->value);
104+
called_funcs.insert(funcs.cbegin(), funcs.cend());
105+
}
106+
auto existing_functions = module->functions;
107+
for (auto f : existing_functions) {
108+
auto it = called_funcs.find(f.first->name_hint);
109+
if (it == called_funcs.end()) {
110+
module->Remove(f.first);
111+
}
112+
}
113+
return module;
114+
}
115+
116+
} // namespace vm
117+
118+
namespace transform {
119+
120+
Pass RemoveUnusedFunctions(Array<tvm::Expr> entry_functions) {
121+
runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func =
122+
[=](Module m, PassContext pc) {
123+
return relay::vm::RemoveUnusedFunctions(m, entry_functions);
124+
};
125+
return CreateModulePass(pass_func, 1, "RemoveUnusedFunctions", {});
126+
}
127+
128+
TVM_REGISTER_API("relay._transform.RemoveUnusedFunctions")
129+
.set_body_typed(RemoveUnusedFunctions);
130+
131+
} // namespace transform
132+
133+
} // namespace relay
134+
} // namespace tvm
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
import tvm
18+
from tvm import relay
19+
from tvm.relay import transform
20+
from tvm.relay.prelude import Prelude
21+
22+
def test_remove_all_prelude_functions():
23+
mod = relay.Module()
24+
p = Prelude(mod)
25+
x = relay.var("x", shape=(1, 16))
26+
mod["main"] = relay.Function([x], x)
27+
mod = relay.transform.RemoveUnusedFunctions()(mod)
28+
l = set([x[0].name_hint for x in mod.functions.items()])
29+
assert l == set(['main'])
30+
31+
def test_remove_all_prelude_functions_but_referenced_functions():
32+
mod = relay.Module()
33+
p = Prelude(mod)
34+
x = relay.var("x", shape=(1, 16))
35+
id_func = relay.Function([x], x)
36+
id_name = relay.GlobalVar('id_func')
37+
mod[id_name] = id_func
38+
39+
mod["main"] = relay.Function([x], id_name(x))
40+
mod = relay.transform.RemoveUnusedFunctions()(mod)
41+
l = set([x[0].name_hint for x in mod.functions.items()])
42+
assert l == set(['id_func', 'main'])
43+
44+
def test_keep_only_referenced_prelude_functions():
45+
mod = relay.Module()
46+
p = Prelude(mod)
47+
l = p.nil()
48+
for i in [4, 3, 2, 1, 0]:
49+
l = p.cons(relay.const(i), l)
50+
body = p.hd(p.tl(p.tl(l)))
51+
mod["main"] = relay.Function([], body)
52+
mod = relay.transform.RemoveUnusedFunctions()(mod)
53+
l = set([x[0].name_hint for x in mod.functions.items()])
54+
assert l == set(['tl', 'hd', 'main'])
55+
56+
def test_multiple_entry_functions():
57+
mod = relay.Module()
58+
p = Prelude(mod)
59+
l = p.nil()
60+
for i in [4, 3, 2, 1, 0]:
61+
l = p.cons(relay.const(i), l)
62+
body = p.hd(p.tl(p.tl(l)))
63+
mod["main1"] = relay.Function([], body)
64+
65+
x = relay.var("x", shape=(1, 16))
66+
id_func = relay.Function([x], x)
67+
id_name = relay.GlobalVar('id_func')
68+
mod[id_name] = id_func
69+
mod["main2"] = relay.Function([x], id_name(x))
70+
mod = relay.transform.RemoveUnusedFunctions(['main1', 'main2'])(mod)
71+
l = set([x[0].name_hint for x in mod.functions.items()])
72+
assert l == set(['tl', 'hd', 'main2', 'id_func', 'main1'])
73+
74+
if __name__ == '__main__':
75+
pytest.main()

0 commit comments

Comments
 (0)