Skip to content

Commit d1df9c8

Browse files
committed
Introduce --interface-api={c,packed} parameter
This introduces structures generated to provide a documented and stable user friendly interface to a TVM generated model, as can be seen in the AOT demo application: ``` struct tvmgen_default_inputs inputs = { .input_1 = input_data, }; struct tvmgen_default_outputs outputs = { .output = output_data, }; int ret_val = tvmgen_default_run(&inputs, &outputs, NULL, NULL); ``` To facilitate this, some other changes are included: * Removed dependency on `aot_executor.{c,h}` in tests, pending the discussion in the interface RFC as to whether we keep them. * Moved creation of test DLTensor's into the AOT test utils, in future this can be replaced by loading via the Python API or otherwise * Introduce `parametrize_aot_options` which can be used to test permutations of AOT which work together - for now this filters C interface and packed operators * Updated demo application to generate the header for demonstration purposes, we should consider porting the demo application to Model Library Format and using the toolchain in the Zephyr App via CMake instead? This patch builds upon the improvements @giuseros made to AOT testing and name mangling from apache#8014
1 parent ab01abc commit d1df9c8

File tree

13 files changed

+611
-245
lines changed

13 files changed

+611
-245
lines changed

apps/microtvm/zephyr/aot_demo/src/main.c

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232

3333
#include "input_data.h"
3434
#include "output_data.h"
35+
#include "tvmgen_default.h"
3536
#include "zephyr_uart.h"
3637

3738
#ifdef CONFIG_ARCH_POSIX
@@ -194,18 +195,18 @@ void main(void) {
194195
}
195196
TVMLogf("Zephyr AOT Runtime\n");
196197

197-
void* inputs[1] = {
198-
input_data,
198+
struct tvmgen_default_inputs inputs = {
199+
.input_1 = input_data,
199200
};
200-
void* outputs[1] = {
201-
output_data,
201+
struct tvmgen_default_outputs outputs = {
202+
.output = output_data,
202203
};
203204

204205
StackMemoryManager_Init(&app_workspace, g_aot_memory, WORKSPACE_SIZE);
205206

206207
double elapsed_time = 0;
207208
TVMPlatformTimerStart();
208-
int ret_val = tvm_runtime_run(&tvmgen_default_network, inputs, outputs);
209+
int ret_val = tvmgen_default_run(&inputs, &outputs, NULL, NULL);
209210
TVMPlatformTimerStop(&elapsed_time);
210211

211212
if (ret_val != 0) {

include/tvm/runtime/module.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,8 @@ constexpr const char* tvm_param_prefix = "__tvm_param__";
232232
constexpr const char* tvm_lookup_linked_param = "_lookup_linked_param";
233233
/*! \brief The main AOT executor function */
234234
constexpr const char* tvm_run_func_suffix = "run_model";
235+
/*! \brief The models entrypoint function which calls the executor */
236+
constexpr const char* tvm_entrypoint_suffix = "run";
235237
} // namespace symbol
236238

237239
// implementations of inline functions.

python/tvm/micro/interface_api.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
"""Defines functions for generating a C interface header"""
19+
20+
import os
21+
22+
from tvm.relay.backend.utils import mangle_module_name
23+
24+
25+
def _emit_brief(header_file, model_name, description):
26+
header_file.write("/*!\n")
27+
header_file.write(f" * \\brief TVM {model_name} model {description} \n")
28+
header_file.write(" */\n")
29+
30+
31+
def generate_c_interface_header(model_name, inputs, outputs, output_path):
32+
"""Generates a C interface header for a given models inputs and outputs
33+
34+
Parameters
35+
----------
36+
model_name : str
37+
Name of the model to be used in defining structs and naming the header
38+
inputs : list[str]
39+
List of model input names to be placed in generated structs
40+
outputs : list[str]
41+
List of model output names to be placed in generated structs
42+
output_path : str
43+
Path to the output folder to generate the header into
44+
"""
45+
46+
mangled_name = mangle_module_name(model_name)
47+
metadata_header = os.path.join(output_path, f"{mangled_name}.h")
48+
with open(metadata_header, "w") as header_file:
49+
_emit_brief(header_file, model_name, "input tensors")
50+
header_file.write(f"struct {mangled_name}_inputs {{\n")
51+
for input_name in inputs:
52+
header_file.write(f" void* {input_name};\n")
53+
header_file.write("};\n\n")
54+
55+
_emit_brief(header_file, model_name, "output tensors")
56+
header_file.write(f"struct {mangled_name}_outputs {{\n")
57+
for output_name in outputs:
58+
header_file.write(f" void* {output_name};\n")
59+
header_file.write("};\n\n")
60+
61+
_emit_brief(header_file, model_name, "memory blocks")
62+
header_file.write(f"struct {mangled_name}_memory {{\n")
63+
header_file.write("};\n\n")
64+
65+
_emit_brief(header_file, model_name, "device configurations")
66+
header_file.write(f"struct {mangled_name}_devices {{\n")
67+
header_file.write("};\n\n")
68+
69+
header_file.write("/*!\n")
70+
header_file.write(f" * \\brief TVM {model_name} model run function \n")
71+
header_file.write(" * \\param inputs Input tensors for the model \n")
72+
header_file.write(" * \\param outputs Output tensors for the model \n")
73+
header_file.write(" * \\param memory Memory blocks for the model to use \n")
74+
header_file.write(" * \\param devices Devices for the model to use \n")
75+
header_file.write(" */\n")
76+
header_file.write(f"int {mangled_name}_run(\n")
77+
header_file.write(f" struct {mangled_name}_inputs* inputs,\n")
78+
header_file.write(f" struct {mangled_name}_outputs* outputs,\n")
79+
header_file.write(f" struct {mangled_name}_memory* memory,\n")
80+
header_file.write(f" struct {mangled_name}_devices* devices\n")
81+
header_file.write(");\n")

python/tvm/micro/model_library_format.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
import re
2424
import tarfile
2525

26+
from tvm.ir.type import TupleType
27+
from .interface_api import generate_c_interface_header
2628
from ..contrib import utils
2729
from ..relay.backend import executor_factory
2830
from ..relay import param_dict
@@ -49,7 +51,6 @@ def _populate_codegen_dir(mod, codegen_dir: str, module_name: str = None):
4951
5052
"""
5153
dso_modules = mod._collect_dso_modules()
52-
dso_module_handles = [m.handle.value for m in dso_modules]
5354
non_dso_modules = mod._collect_from_import_tree(lambda m: m not in dso_modules)
5455
if non_dso_modules:
5556
raise UnsupportedInModelLibraryFormatError(
@@ -207,6 +208,42 @@ def _build_function_memory_map(function_metadata):
207208
return ret
208209

209210

211+
def _get_main_relay_func(mod: executor_factory.ExecutorFactoryModule):
212+
main_func = mod.function_metadata[MAIN_FUNC_NAME_STR]
213+
target = list(main_func.relay_primfuncs.keys())[0]
214+
return main_func.relay_primfuncs[target]
215+
216+
217+
def _convert_tuple_to_outputs(ret_type, offset=0):
218+
outputs = []
219+
added_fields = len(ret_type.fields)
220+
for output_index in range(added_fields):
221+
next_output = offset + len(outputs)
222+
if isinstance(ret_type.fields[output_index], TupleType):
223+
outputs.extend(_convert_tuple_to_outputs(ret_type.fields[output_index], next_output))
224+
else:
225+
outputs.append(f"output{next_output}")
226+
return outputs
227+
228+
229+
def _get_inputs_and_outputs_from_module(mod):
230+
main_func = _get_main_relay_func(mod)
231+
inputs = [argument.name_hint for argument in main_func.params]
232+
233+
outputs = ["output"]
234+
if isinstance(main_func.ret_type, TupleType):
235+
outputs = _convert_tuple_to_outputs(main_func.ret_type)
236+
237+
return inputs, outputs
238+
239+
240+
def _should_generate_interface_header(mod):
241+
for _, target in mod.target.items():
242+
if "interface-api" in target.attrs and target.attrs["interface-api"] == "c":
243+
return True
244+
return False
245+
246+
210247
def export_model_library_format(mod: executor_factory.ExecutorFactoryModule, file_name):
211248
"""Export the build artifact in Model Library Format.
212249
@@ -246,6 +283,12 @@ def export_model_library_format(mod: executor_factory.ExecutorFactoryModule, fil
246283
os.mkdir(codegen_dir_path)
247284
_populate_codegen_dir(mod.lib, codegen_dir_path, mod.libmod_name)
248285

286+
if _should_generate_interface_header(mod):
287+
include_path = os.path.join(codegen_dir_path, "host/include")
288+
os.mkdir(include_path)
289+
inputs, outputs = _get_inputs_and_outputs_from_module(mod)
290+
generate_c_interface_header(mod.libmod_name, inputs, outputs, include_path)
291+
249292
parameters_dir_path = tempdir.relpath("parameters")
250293
os.mkdir(parameters_dir_path)
251294
param_filename = os.path.join(parameters_dir_path, f"{mod.libmod_name}.params")

src/relay/backend/aot_executor_codegen.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -652,7 +652,7 @@ class AOTExecutorCodegen : public ExprVisitor {
652652
/*! \brief mod */
653653
runtime::Module* mod_;
654654
/*! \brief list of input expressions (i.e., variable passed by the user) */
655-
std::vector<Expr> input_vars_;
655+
std::vector<Var> input_vars_;
656656
/*! \brief input and output variables belonging to the main function signature */
657657
Array<tir::Var> main_signature_;
658658
/*! \brief target device */
@@ -783,8 +783,8 @@ class AOTExecutorCodegen : public ExprVisitor {
783783
ret.lowered_funcs.Set(target_host_str, mod_run);
784784
}
785785
ret.function_metadata = std::move(function_metadata_);
786-
ret.metadata = runtime::Metadata(input_vars_.size(), return_sid_.size(),
787-
runtime::kTvmExecutorAot, mod_name);
786+
787+
ret.metadata = runtime::Metadata(input_vars_, return_sid_.size(), runtime::kTvmExecutorAot, mod_name);
788788
return ret;
789789
}
790790
};

src/runtime/meta_data.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727
#include <dmlc/io.h>
2828
#include <dmlc/json.h>
29+
#include <tvm/relay/expr.h>
2930
#include <tvm/runtime/executor_info.h>
3031
#include <tvm/runtime/module.h>
3132
#include <tvm/runtime/ndarray.h>
@@ -54,8 +55,8 @@ inline String get_name_mangled(const String& module_name, const String& name) {
5455
*/
5556
class MetadataNode : public Object {
5657
public:
57-
/*! \brief number of inputs of the main function */
58-
int num_inputs = 1;
58+
/*! \brief input information for the main function */
59+
Array<tvm::relay::Var> inputs;
5960
/*! \brief number of outputs of the main function */
6061
int num_outputs = 1;
6162
/*! \brief the executor to be used to run the model */
@@ -73,9 +74,9 @@ class MetadataNode : public Object {
7374
*/
7475
class Metadata : public ObjectRef {
7576
public:
76-
TVM_DLL Metadata(int num_inputs, int num_outputs, String executor, String mod_name) {
77+
TVM_DLL Metadata(Array<tvm::relay::Var> inputs, int num_outputs, String executor, String mod_name) {
7778
auto n = make_object<MetadataNode>();
78-
n->num_inputs = num_inputs;
79+
n->inputs = inputs;
7980
n->num_outputs = num_outputs;
8081
n->executor = executor;
8182
n->mod_name = mod_name;

src/target/source/source_module.cc

Lines changed: 70 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -192,25 +192,26 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode {
192192
<< "}\n";
193193
}
194194

195-
void GenerateEntrypointForUnpackedAPI(const std::string& run_func) {
195+
void GenerateEntrypointForUnpackedAPI(const std::string& entrypoint_name,
196+
const std::string& run_func) {
196197
code_ << "TVM_DLL int32_t " << run_func << "(";
197-
int total_args = (metadata_->num_inputs + metadata_->num_outputs);
198-
for (int i = 0; i < total_args; ++i) {
199-
code_ << "arg" << i;
198+
unsigned int total_args = (metadata_->inputs.size() + metadata_->num_outputs);
199+
for (unsigned int i = 0; i < total_args; ++i) {
200+
code_ << "void* arg" << i;
200201
if (i + 1 != total_args) {
201202
code_ << ",";
202203
}
203204
}
204205
code_ << ");\n";
205-
code_ << "static int32_t " << ::tvm::runtime::symbol::tvm_module_main;
206+
code_ << "int32_t " << entrypoint_name;
206207
code_ << "(void* args, void* type_code, int num_args, void* out_value, void* "
207208
"out_type_code, void* resource_handle) {\n";
208209
code_ << "return " << run_func << "(";
209-
for (int i = 0; i < metadata_->num_inputs; ++i) {
210+
for (unsigned int i = 0; i < metadata_->inputs.size(); ++i) {
210211
code_ << "((DLTensor*)(((TVMValue*)args)[" << i << "].v_handle))[0].data,";
211212
}
212213
for (int i = 0; i < metadata_->num_outputs; ++i) {
213-
int j = metadata_->num_inputs + i;
214+
int j = metadata_->inputs.size() + i;
214215
code_ << "((DLTensor*)(((TVMValue*)args)[" << j << "].v_handle))[0].data";
215216
if (i + 1 != metadata_->num_outputs) {
216217
code_ << ",";
@@ -220,37 +221,85 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode {
220221
code_ << "}\n";
221222
}
222223

223-
void GenerateEntrypointForPackedAPI(const std::string& run_func) {
224+
void GenerateEntrypointForPackedAPI(const std::string& entrypoint_name,
225+
const std::string& run_func) {
224226
code_ << "TVM_DLL int32_t " << run_func;
225227
code_ << "(void* args, void* type_code, int num_args, void* out_value, void* "
226228
"out_type_code, void* resource_handle);\n";
227-
code_ << "static int32_t " << ::tvm::runtime::symbol::tvm_module_main;
229+
code_ << "int32_t " << entrypoint_name;
228230
code_ << "(void* args, void* type_code, int num_args, void* out_value, void* "
229231
"out_type_code, void* resource_handle) {\n";
230232
code_ << "return " << run_func;
231233
code_ << "(args, type_code, num_args, out_value, out_type_code, resource_handle);\n";
232234
code_ << "}\n";
233235
}
234236

237+
void GenerateCInterfaceEntrypoint(const std::string& entrypoint_name, const std::string& run_func,
238+
const std::string& mod_name) {
239+
code_ << "#include <" << mod_name << ".h>\n";
240+
code_ << "TVM_DLL int32_t " << run_func << "(";
241+
unsigned int total_args = (metadata_->inputs.size() + metadata_->num_outputs);
242+
for (unsigned int i = 0; i < total_args; ++i) {
243+
code_ << "void* arg" << i;
244+
if (i + 1 != total_args) {
245+
code_ << ",";
246+
}
247+
}
248+
code_ << ");\n";
249+
code_ << "int32_t " << entrypoint_name << "(";
250+
code_ << "struct " << runtime::get_name_mangled(mod_name, "inputs") << "* inputs,"
251+
<< "struct " << runtime::get_name_mangled(mod_name, "outputs") << "* outputs,"
252+
<< "struct " << runtime::get_name_mangled(mod_name, "memory") << "* memory,"
253+
<< "struct " << runtime::get_name_mangled(mod_name, "devices") << "* devices"
254+
<< ") {";
255+
code_ << "return " << run_func << "(";
256+
for (const auto& input : metadata_->inputs) {
257+
code_ << "inputs->" << input->name_hint() << ",";
258+
}
259+
if (metadata_->num_outputs == 1) {
260+
code_ << "outputs->output";
261+
} else {
262+
for (int i = 0; i < metadata_->num_outputs; ++i) {
263+
code_ << "outputs->output" << i;
264+
if (i + 1 != metadata_->num_outputs) {
265+
code_ << ",";
266+
}
267+
}
268+
}
269+
code_ << ");\n";
270+
code_ << "}\n";
271+
}
272+
235273
void GenerateAOTDescriptor() {
236-
const std::string run_func = ::tvm::runtime::symbol::tvm_run_func_suffix;
237-
const std::string run_func_mangled = runtime::get_name_mangled(metadata_->mod_name, run_func);
274+
const std::string run_func_suffix = ::tvm::runtime::symbol::tvm_run_func_suffix;
275+
const std::string tvm_entrypoint_suffix = ::tvm::runtime::symbol::tvm_entrypoint_suffix;
276+
const std::string run_func_mangled =
277+
runtime::get_name_mangled(metadata_->mod_name, run_func_suffix);
278+
const std::string entrypoint_mangled =
279+
runtime::get_name_mangled(metadata_->mod_name, tvm_entrypoint_suffix);
238280
const std::string network_mangled = runtime::get_name_mangled(metadata_->mod_name, "network");
239-
code_ << "#include \"tvm/runtime/crt/internal/aot_executor/aot_executor.h\"\n";
281+
auto unpacked_api = target_->GetAttr<Bool>("unpacked-api").value_or(Bool(false));
282+
auto interface_api = target_->GetAttr<String>("interface-api").value_or(String("packed"));
283+
240284
code_ << "#include \"tvm/runtime/c_runtime_api.h\"\n";
241285
code_ << "#ifdef __cplusplus\n";
242-
code_ << "extern \"C\"\n";
286+
code_ << "extern \"C\" {\n";
243287
code_ << "#endif\n";
244-
if (target_->GetAttr<Bool>("unpacked-api").value_or(Bool(false))) {
245-
GenerateEntrypointForUnpackedAPI(run_func_mangled);
288+
289+
if (unpacked_api) {
290+
if (interface_api == "c") {
291+
GenerateCInterfaceEntrypoint(entrypoint_mangled, run_func_mangled, metadata_->mod_name);
292+
} else {
293+
GenerateEntrypointForUnpackedAPI(entrypoint_mangled, run_func_mangled);
294+
}
246295
} else {
247-
GenerateEntrypointForPackedAPI(run_func_mangled);
296+
ICHECK_EQ(interface_api, "packed") << "Packed interface required for packed operators";
297+
GenerateEntrypointForPackedAPI(entrypoint_mangled, run_func_mangled);
248298
}
249-
code_ << "const tvm_model_t " << network_mangled << " = {\n"
250-
<< " .run_func = &" << ::tvm::runtime::symbol::tvm_module_main << ",\n"
251-
<< " .num_input_tensors = " << metadata_->num_inputs << ",\n"
252-
<< " .num_output_tensors = " << metadata_->num_outputs << ", \n"
253-
<< "};\n";
299+
300+
code_ << "#ifdef __cplusplus\n";
301+
code_ << "}\n";
302+
code_ << "#endif\n";
254303
}
255304

256305
void CreateSource() {

src/target/target_kind.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,7 @@ TVM_REGISTER_TARGET_KIND("llvm", kDLCPU)
299299
.add_attr_option<String>("runtime")
300300
.add_attr_option<Bool>("link-params", Bool(false))
301301
.add_attr_option<Bool>("unpacked-api")
302+
.add_attr_option<String>("interface-api")
302303
.set_default_keys({"cpu"});
303304

304305
TVM_REGISTER_TARGET_KIND("c", kDLCPU)
@@ -310,6 +311,7 @@ TVM_REGISTER_TARGET_KIND("c", kDLCPU)
310311
.add_attr_option<String>("executor")
311312
.add_attr_option<Integer>("workspace-byte-alignment")
312313
.add_attr_option<Bool>("unpacked-api")
314+
.add_attr_option<String>("interface-api")
313315
.set_default_keys({"cpu"});
314316

315317
TVM_REGISTER_TARGET_KIND("cuda", kDLCUDA)

0 commit comments

Comments
 (0)