Skip to content

Commit cbb0d21

Browse files
author
Giuseppe Rossini
committed
addressing comments
Change-Id: If51a1f90e81f774ef4340c72871675862629867b
1 parent 71dd730 commit cbb0d21

File tree

10 files changed

+165
-43
lines changed

10 files changed

+165
-43
lines changed

python/tvm/micro/model_library_format.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ class UnsupportedInModelLibraryFormatError(Exception):
3535
"""Raised when export_model_library_format does not support the given Module tree."""
3636

3737

38-
def _populate_codegen_dir(mod, codegen_dir: str, model_name: str = None):
38+
def _populate_codegen_dir(mod, codegen_dir: str, module_name: str = None):
3939
"""Populate the codegen sub-directory as part of a Model Library Format export.
4040
4141
Parameters
@@ -44,6 +44,9 @@ def _populate_codegen_dir(mod, codegen_dir: str, model_name: str = None):
4444
Module which should be written to codegen_dir.
4545
codegen_dir : str
4646
Path to the codegen directory on disk.
47+
module_name: Optional[str]
48+
Name used to prefix the generated source files
49+
4750
"""
4851
dso_modules = mod._collect_dso_modules()
4952
dso_module_handles = [m.handle.value for m in dso_modules]
@@ -55,7 +58,7 @@ def _populate_codegen_dir(mod, codegen_dir: str, model_name: str = None):
5558

5659
mod_indices = {"lib": 0, "src": 0}
5760
host_codegen_dir = os.path.join(codegen_dir, "host")
58-
lib_name = f"{model_name}_lib" if model_name else "lib"
61+
lib_name = f"{module_name}_lib" if module_name else "lib"
5962

6063
for dso_mod in dso_modules:
6164
if dso_mod.type_key == "c":
@@ -223,7 +226,7 @@ def export_model_library_format(mod: executor_factory.ExecutorFactoryModule, fil
223226
runtime = ["aot"] if is_aot else ["graph"]
224227

225228
metadata = {
226-
"version": 2,
229+
"version": 3,
227230
"model_name": mod.libmod_name,
228231
"export_datetime": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%SZ"),
229232
"memory": _build_memory_map(mod),

python/tvm/relay/build_module.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,20 @@ def _convert_param_map(params):
7070
return inputs
7171

7272

73+
def _is_valid_modname(mod_name):
74+
"""Determine if mod_name is a valid string to use inside
75+
function names
76+
"""
77+
if mod_name:
78+
try:
79+
mod_name.encode("ascii")
80+
return True
81+
except e:
82+
return False
83+
84+
return True
85+
86+
7387
class BuildModule(object):
7488
"""Build an IR module to run on TVM graph executor. This class is used
7589
to expose the `RelayBuildModule` APIs implemented in C++.
@@ -336,6 +350,9 @@ def build(ir_mod, target=None, target_host=None, params=None, mod_name="default"
336350
else:
337351
tophub_context = autotvm.utils.EmptyContext()
338352

353+
if not _is_valid_modname(mod_name):
354+
raise ValueError(mod_name + " contains invalid characters")
355+
339356
with tophub_context:
340357
bld_mod = BuildModule()
341358
executor_config, runtime_mod, params = bld_mod.build(

python/tvm/relay/transform/transform.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -726,17 +726,16 @@ def PartitionGraph():
726726
return _ffi_api.PartitionGraph()
727727

728728

729-
def PartitionGraphWithModName(mod_name):
730-
"""Partition a Relay program into regions that can be executed on different
731-
backends. This version also accepts a mod_name to prefix in front of every
732-
generated function.
729+
def MangleGraph(mod_name):
730+
"""Mangle the functions that are supposed to be externally compiled
731+
with a given mod_name
733732
734733
Returns
735734
-------
736735
ret: tvm.transform.Pass
737-
The registered pass that partitions the Relay program.
736+
The registered pass that mangles the Relay program.
738737
"""
739-
return _ffi_api.PartitionGraphWithModName(mod_name)
738+
return _ffi_api.MangleGraph(mod_name)
740739

741740

742741
def AnnotateTarget(targets, include_non_call_ops=True):

src/relay/backend/aot_executor_codegen.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -534,6 +534,7 @@ class AOTExecutorCodegen : public ExprVisitor {
534534
String run_func_name =
535535
runtime::get_name_mangled(mod_name_, runtime::symbol::tvm_run_func_suffix);
536536
dict_attrs.Set("global_symbol", run_func_name);
537+
dict_attrs.Set("runner_function", Bool(true));
537538

538539
// Make the PrimFunc
539540
return tir::PrimFunc(main_signature_, body, VoidType(), Map<tir::Var, tir::Buffer>(),
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* \file src/relay/transforms/mangle_graph.cc
22+
*
23+
* \brief Walk the graph to find external compiled function and to
24+
* "mangle" them, i.e., replace func_name with tvmgen_MODNAME_func_name
25+
*/
26+
27+
#include <tvm/ir/error.h>
28+
#include <tvm/relay/expr.h>
29+
#include <tvm/relay/expr_functor.h>
30+
#include <tvm/relay/transform.h>
31+
32+
#include <unordered_map>
33+
34+
#include "../backend/utils.h"
35+
#include "pass_utils.h"
36+
37+
namespace tvm {
38+
namespace relay {
39+
namespace mangling {
40+
41+
class Mangler : public MixedModeMutator {
42+
public:
43+
explicit Mangler(const IRModule& module, std::function<String(String)> mangle_fn)
44+
: module_(module), mangle_fn_(mangle_fn) {}
45+
46+
IRModule Mangle() {
47+
auto glob_funcs = module_->functions;
48+
49+
// Collect function names to be mangled and create
50+
// global mangled variables
51+
for (const auto& pair : glob_funcs) {
52+
if (auto* fn = pair.second.as<FunctionNode>()) {
53+
auto func = GetRef<Function>(fn);
54+
if (func->GetAttr<String>(attr::kCompiler).defined()) {
55+
auto fn_name_mangled = mangle_fn_(pair.first->name_hint);
56+
GlobalVar gvar = GlobalVar(fn_name_mangled);
57+
mangled_gvars_[pair.first->name_hint] = gvar;
58+
}
59+
}
60+
}
61+
62+
// Walk the three and mangle the functions. Then replace compiler functions
63+
// with mangled functions in the module
64+
IRModule new_module;
65+
for (const auto& pair : glob_funcs) {
66+
if (auto* fn = pair.second.as<FunctionNode>()) {
67+
auto func = GetRef<Function>(fn);
68+
func = Function(func->params, VisitExpr(func->body), func->ret_type, func->type_params,
69+
func->attrs);
70+
if (func->GetAttr<String>(attr::kCompiler).defined()) {
71+
new_module->Add(mangled_gvars_[pair.first->name_hint], func);
72+
} else {
73+
new_module->Add(pair.first, func);
74+
}
75+
}
76+
}
77+
78+
return new_module;
79+
}
80+
81+
private:
82+
Expr Rewrite_(const CallNode* call, const Expr& post) final {
83+
Expr new_expr = post;
84+
const CallNode* new_call = new_expr.as<CallNode>();
85+
auto op_node = new_call->op.as<GlobalVarNode>();
86+
if (op_node == nullptr || mangled_gvars_.find(op_node->name_hint) == mangled_gvars_.end()) {
87+
return new_expr;
88+
} else {
89+
return Call(mangled_gvars_[op_node->name_hint], new_call->args, new_call->attrs,
90+
new_call->type_args, new_call->span);
91+
}
92+
}
93+
94+
/*!\brief The IRModule used for partitioning. */
95+
IRModule module_;
96+
/*!\brief The function used to mangle operators name */
97+
std::function<String(String)> mangle_fn_;
98+
/*!\brief Tabled used to store (unmangled_var_name, mangled_gvar) pairs*/
99+
std::unordered_map<std::string, GlobalVar> mangled_gvars_;
100+
};
101+
102+
} // namespace mangling
103+
104+
namespace transform {
105+
106+
Pass MangleGraph(String mod_name) {
107+
auto mangle_fn = [mod_name](String name) { return runtime::get_name_mangled(mod_name, name); };
108+
109+
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> mangle_functions =
110+
[=](IRModule m, PassContext pc) { return mangling::Mangler(m, mangle_fn).Mangle(); };
111+
return CreateModulePass(mangle_functions, 0, "MangleFunctions", {});
112+
}
113+
114+
TVM_REGISTER_GLOBAL("relay._transform.MangleGraph").set_body_typed(transform::MangleGraph);
115+
116+
} // namespace transform
117+
118+
} // namespace relay
119+
} // namespace tvm

src/relay/transforms/partition_graph.cc

Lines changed: 5 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -113,8 +113,7 @@ struct RegionFuncMetadata {
113113

114114
class Partitioner : public MixedModeMutator {
115115
public:
116-
explicit Partitioner(const IRModule& module, std::function<String(String)> mangle_fn)
117-
: module_(module), mangle_fn_(mangle_fn) {
116+
explicit Partitioner(const IRModule& module) : module_(module) {
118117
for (auto f : module->functions) {
119118
GlobalVar f_var = f.first;
120119
BaseFunc f_func = f.second;
@@ -303,7 +302,7 @@ class Partitioner : public MixedModeMutator {
303302
}
304303

305304
std::string target = end_node->attrs.as<CompilerAttrs>()->compiler;
306-
std::string name = mangle_fn_(target) + "_" + std::to_string(region->GetID());
305+
std::string name = target + "_" + std::to_string(region->GetID());
307306

308307
// Constant propagation
309308
if (!params_bind.empty()) {
@@ -393,8 +392,6 @@ class Partitioner : public MixedModeMutator {
393392

394393
/*!\brief The IRModule used for partitioning. */
395394
IRModule module_;
396-
/*!\brief The function used to mangle operators name */
397-
std::function<String(String)> mangle_fn_;
398395
};
399396

400397
IRModule RemoveDefaultAnnotations(IRModule module) {
@@ -487,7 +484,7 @@ IRModule FlattenTupleOutputs(IRModule module) {
487484

488485
namespace transform {
489486

490-
Pass PartitionGraphCommon(std::function<String(String)> mangle_fn) {
487+
Pass PartitionGraph() {
491488
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> flatten_tuples = [=](IRModule m,
492489
PassContext pc) {
493490
// There could be compiler_end annotations on tuples
@@ -506,35 +503,16 @@ Pass PartitionGraphCommon(std::function<String(String)> mangle_fn) {
506503
return partitioning::RemoveDefaultAnnotations(m);
507504
};
508505

509-
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> part_func = [=](IRModule m,
510-
PassContext pc) {
511-
return partitioning::Partitioner(m, mangle_fn).Partition();
512-
};
506+
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> part_func =
507+
[=](IRModule m, PassContext pc) { return partitioning::Partitioner(m).Partition(); };
513508

514509
auto flatten_tuples_pass = CreateModulePass(flatten_tuples, 0, "FlattenNestedTuples", {});
515510
auto remove_default_pass = CreateModulePass(remove_defaults, 0, "RemoveDefaultAnnotations", {});
516511
auto partition_pass = CreateModulePass(part_func, 0, "PartitionGraph", {});
517512
return Sequential({flatten_tuples_pass, remove_default_pass, partition_pass, InferType()});
518513
}
519514

520-
Pass PartitionGraph() {
521-
// Default version. All the function signatures will be "COMPILER_ID"
522-
auto mangle_fn = [](String name) { return name; };
523-
return PartitionGraphCommon(mangle_fn);
524-
}
525-
526-
Pass PartitionGraphWithModName(String mod_name) {
527-
// Mangled version. The user has the opportunity to provide a module name. This
528-
// is useful if more models are being compiled in the same library. The signature
529-
// for every function is "MODNAME_COMPILER_ID". User can compile different models
530-
// with different names, so that more models can coexist in the same library.
531-
auto mangle_fn = [mod_name](String name) { return runtime::get_name_mangled(mod_name, name); };
532-
return PartitionGraphCommon(mangle_fn);
533-
}
534-
535515
TVM_REGISTER_GLOBAL("relay._transform.PartitionGraph").set_body_typed(transform::PartitionGraph);
536-
TVM_REGISTER_GLOBAL("relay._transform.PartitionGraphWithModName")
537-
.set_body_typed(transform::PartitionGraphWithModName);
538516

539517
} // namespace transform
540518

src/target/source/codegen_c_host.cc

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,13 @@ void CodeGenCHost::AddFunction(const PrimFunc& f) {
6363
}
6464

6565
void CodeGenCHost::LinkParameters(Map<String, LinkedParam> params, bool static_linking) {
66-
// TODO(giuseros): remove this once we don't need an external function call to
67-
// retrieve the parameters.
66+
// When we use static linking, the function will be only visible inside the compilation
67+
// unit. This applies to the case in which we are not using an external interpreter
68+
// (i.e., AOT) and the runner function will be anyway in the same compilation unit of
69+
// the look_up_linked_param function.
70+
// TODO(giuseros): once we remove the call to "tvm_lookup_linked_param" in AOT, we won't
71+
// need this anymore.
72+
6873
if (static_linking) {
6974
stream << "static int32_t";
7075
} else {
@@ -392,8 +397,7 @@ runtime::Module BuildCHost(IRModule mod, Target target) {
392397
// Make sure that the executor function is the last one to be code generated so that all the
393398
// symbols are available to tvm_run_func
394399
auto fun_name = std::string(kv.first->name_hint);
395-
const bool is_aot_executor_fn =
396-
(fun_name.rfind(::tvm::runtime::symbol::tvm_run_func_suffix) != std::string::npos);
400+
bool is_aot_executor_fn = kv.second->GetAttr<Bool>("runner_function", Bool(false)).value();
397401

398402
if (is_aot_executor_fn) {
399403
aot_executor_fn = Downcast<PrimFunc>(kv.second);

tests/python/relay/aot/aot_test.mk

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
# KIND, either express or implied. See the License for the
1515
# specific language governing permissions and limitations
1616
# under the License.
17-
# Makefile to build ethosu_test_runner
1817
# Setup build environment
1918
#
2019
AOT_ROOT ?= $(TVM_ROOT)/src/runtime/crt/aot

tests/python/relay/aot/aot_test_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ def create_main(test_name, input_list_map, output_list_map, output_path):
180180
def create_header_file(tensor_name, npy_data, output_path):
181181
"""
182182
This method generates a header file containing the data contained in the numpy array provided.
183-
It is used to capture the tensor data (for both inputs and expected outputs) to be bundled into the standalone ethosu_test_runner.
183+
It is used to capture the tensor data (for both inputs and expected outputs) to be bundled into the standalone application.
184184
"""
185185
file_path = pathlib.Path(f"{output_path}/" + tensor_name).resolve()
186186
# create header file

tests/python/relay/aot/test_crt_aot.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,9 @@ def test_byoc_utvm():
329329
mod = tvm.IRModule()
330330
ann = CcompilerAnnotator()
331331
mod["main"] = ann.visit(f)
332-
mod = tvm.relay.transform.PartitionGraphWithModName("my_mod")(mod)
332+
333+
mod = tvm.relay.transform.PartitionGraph()(mod)
334+
mod = tvm.relay.transform.MangleGraph("my_mod")(mod)
333335
mod = tvm.relay.transform.InferType()(mod)
334336

335337
x_data = np.random.rand(10, 10).astype("float32")

0 commit comments

Comments
 (0)